MLIR  22.0.0git
NVGPUTransformOps.cpp
Go to the documentation of this file.
1 //===- NVGPUTransformOps.cpp - Implementation of NVGPU transform ops ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
29 #include "mlir/IR/AffineExpr.h"
30 #include "mlir/IR/BuiltinTypes.h"
31 #include "mlir/IR/Value.h"
32 #include "llvm/ADT/ArrayRef.h"
33 
34 using namespace mlir;
35 using namespace mlir::linalg;
36 using namespace mlir::nvgpu;
37 using namespace mlir::NVVM;
38 using namespace mlir::transform;
39 
40 #define DEBUG_TYPE "nvgpu-transforms"
41 
42 //===----------------------------------------------------------------------===//
43 // Apply...ConversionPatternsOp
44 //===----------------------------------------------------------------------===//
45 
46 void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
47  TypeConverter &typeConverter, RewritePatternSet &patterns) {
48  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
49  /// device-side async tokens cannot be materialized in nvvm. We just
50  /// convert them to a dummy i32 type in order to easily drop them during
51  /// conversion.
53  llvmTypeConverter, [](gpu::AddressSpace space) -> unsigned {
54  switch (space) {
55  case gpu::AddressSpace::Global:
56  return static_cast<unsigned>(
58  case gpu::AddressSpace::Workgroup:
59  return static_cast<unsigned>(
61  case gpu::AddressSpace::Private:
62  return 0;
63  }
64  llvm_unreachable("unknown address space enum value");
65  return 0;
66  });
67  llvmTypeConverter.addConversion(
68  [&](nvgpu::DeviceAsyncTokenType type) -> Type {
69  return llvmTypeConverter.convertType(
70  IntegerType::get(type.getContext(), 32));
71  });
72  llvmTypeConverter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
73  return llvmTypeConverter.convertType(
74  IntegerType::get(type.getContext(), 64));
75  });
76  llvmTypeConverter.addConversion(
77  [&](nvgpu::WarpgroupAccumulatorType type) -> Type {
78  Type elemType = type.getFragmented().getElementType();
79  int64_t sizeM = type.getFragmented().getDimSize(0);
80  int64_t sizeN = type.getFragmented().getDimSize(1);
81 
82  unsigned numMembers;
83  if (elemType.isF32() || elemType.isInteger(32))
84  numMembers = sizeN / 2;
85  else if (elemType.isF16())
86  numMembers = sizeN / 4;
87  else
88  llvm_unreachable("unsupported type for warpgroup accumulator");
89 
90  SmallVector<Type> innerStructBody;
91  for (unsigned i = 0; i < numMembers; i++)
92  innerStructBody.push_back(elemType);
93  auto innerStructType = LLVM::LLVMStructType::getLiteral(
94  type.getContext(), innerStructBody);
95 
96  SmallVector<Type> structBody;
97  for (int i = 0; i < sizeM; i += kWgmmaSizeM)
98  structBody.push_back(innerStructType);
99 
100  auto convertedType =
101  LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
102  return llvmTypeConverter.convertType(convertedType);
103  });
104  llvmTypeConverter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
105  return llvmTypeConverter.convertType(
106  getMBarrierMemrefType(type.getContext(), type));
107  });
108  llvmTypeConverter.addConversion(
109  [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
110  return llvmTypeConverter.convertType(
111  IntegerType::get(type.getContext(), 64));
112  });
113  llvmTypeConverter.addConversion(
114  [&](nvgpu::TensorMapDescriptorType type) -> Type {
115  return LLVM::LLVMPointerType::get(type.getContext());
116  });
118 }
119 
120 LogicalResult
121 transform::ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter(
122  transform::TypeConverterBuilderOpInterface builder) {
123  if (builder.getTypeConverterType() != "LLVMTypeConverter")
124  return emitOpError("expected LLVMTypeConverter");
125  return success();
126 }
127 
128 //===---------------------------------------------------------------------===//
129 // CreateAsyncGroupsOp
130 //===---------------------------------------------------------------------===//
131 
132 void transform::CreateAsyncGroupsOp::getEffects(
134  transform::consumesHandle(getTargetMutable(), effects);
135  transform::producesHandle(getOperation()->getOpResults(), effects);
137 }
138 
139 DiagnosedSilenceableFailure transform::CreateAsyncGroupsOp::applyToOne(
140  TransformRewriter &rewriter, Operation *target,
141  ApplyToEachResultList &results, TransformState &state) {
142  nvgpu::createAsyncGroups(rewriter, target, getBypassL1());
143  results.push_back(target);
145 }
146 
147 //===----------------------------------------------------------------------===//
148 // PipelineSharedMemoryCopiesOp
149 //===----------------------------------------------------------------------===//
150 
151 /// Returns true if the given type has the default memory space.
153  return !type.getMemorySpace() || type.getMemorySpaceAsInt() == 0;
154 }
155 
156 /// Returns true if the given type has the shared (workgroup) memory space.
158  auto space =
159  dyn_cast_if_present<gpu::AddressSpaceAttr>(type.getMemorySpace());
160  return space &&
161  space.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
162 }
163 
164 /// Returns the value produced by a load from the default memory space. Returns
165 /// null if the operation is not such a load.
167  // TODO: consider an interface or leveraging the memory effects interface.
168  auto load = dyn_cast<vector::TransferReadOp>(op);
169  if (!load)
170  return nullptr;
171 
172  auto loadType = dyn_cast<MemRefType>(load.getBase().getType());
173  if (!loadType || !hasDefaultMemorySpace(loadType))
174  return nullptr;
175  return load;
176 }
177 
178 /// Returns true if the operation is storing the given value into shared memory.
179 static bool isStoreToShared(Operation *op, Value v) {
180  // TOD: consider an interface or leveraging the memory effects interface.
181  auto store = dyn_cast<vector::TransferWriteOp>(op);
182  if (!store || store.getVector() != v)
183  return false;
184 
185  auto storeType = dyn_cast<MemRefType>(store.getBase().getType());
186  return storeType || hasSharedMemorySpace(storeType);
187 }
188 
189 /// Returns true if the operation is a load from the default memory space the
190 /// result of which is only stored into the shared memory space.
192  Value loaded = getValueLoadedFromGlobal(op);
193  if (!loaded || !loaded.hasOneUse())
194  return false;
195 
196  return isStoreToShared(*loaded.getUsers().begin(), loaded);
197 }
198 
199 /// Populate `ops` with the set of operations that belong to the stage 0 of the
200 /// pipelined version of the given loop when pipelining copies to shared memory.
201 /// Specifically, this collects:
202 ///
203 /// 1. all loads from global memory, both sync and async;
204 /// 2. the barriers for async loads.
205 ///
206 /// In particular, barriers are omitted if they do not dominate at least one
207 /// async load for which there is not yet a barrier.
208 static LogicalResult
209 collectStage0PipeliningOps(scf::ForOp forOp,
211 
213  for (Operation &op : *forOp.getBody()) {
214  // Bail on nested ops for now.
215  if (op.getNumRegions() > 0)
216  return failure();
217 
218  if (isa<gpu::BarrierOp>(op)) {
219  barriers.insert(&op);
220  continue;
221  }
222 
223  if (isa<nvgpu::DeviceAsyncCopyOp, nvgpu::DeviceAsyncCreateGroupOp>(op)) {
224  ops.insert(&op);
225  ops.insert(std::make_move_iterator(barriers.begin()),
226  std::make_move_iterator(barriers.end()));
227  assert(barriers.empty() &&
228  "expected to have moved the barriers into another set");
229  continue;
230  }
231 
233  ops.insert(&op);
234  continue;
235  }
236  }
237 
238  return success();
239 }
240 
241 /// Hook for the loop pipeliner that sets the "num groups in flight" attribute
242 /// of async wait operations corresponding to pipelined shared memory copies.
243 // TODO: this currently assumes that there are no groups that could be in flight
244 // in the existing code.
245 static void
248  unsigned iteration, unsigned depth) {
249  // Based on the order of copies within the loop we need to set the number
250  // of copies in flight, unless it is already set.
251  auto waitOp = dyn_cast<nvgpu::DeviceAsyncWaitOp>(op);
252  if (!waitOp || waitOp.getNumGroups())
253  return;
254 
255  int numGroupInFlight = 0;
258  numGroupInFlight = depth - 1;
259  } else {
260  // By construction there should be no wait op in the prologue as all the
261  // wait should be in the last stage.
263  // Based on the schedule we pick we know how many groups are in flight for
264  // each iteration of the epilogue.
265  numGroupInFlight = depth - 1 - iteration;
266  }
267  waitOp.setNumGroups(numGroupInFlight);
268 }
269 
270 /// Hook for the loop pipeliner that populates `ops` with the stage information
271 /// as follows:
272 ///
273 /// - operations in `stage0Ops` (typically loads from global memory and
274 /// related barriers) are at stage 0;
275 /// - operations in the backward slice of any stage0Ops are all at stage 0;
276 /// - other operations are at stage `depth`;
277 /// - the internal order of the pipelined loop has ops at stage `depth` first,
278 /// then those at stage 0, with relative order within each group preserved.
279 ///
280 static void getPipelineStages(
281  scf::ForOp forOp,
282  std::vector<std::pair<Operation *, unsigned>> &opsWithPipelineStages,
283  unsigned depth, llvm::SmallPtrSetImpl<Operation *> &stage0Ops) {
284  SetVector<Operation *> dependencies;
285  BackwardSliceOptions options([&](Operation *visited) {
286  return visited->getBlock() == forOp.getBody();
287  });
288  options.inclusive = true;
289  for (Operation &op : forOp.getBody()->getOperations()) {
290  if (stage0Ops.contains(&op)) {
291  LogicalResult result = getBackwardSlice(&op, &dependencies, options);
292  assert(result.succeeded() && "expected a backward slice");
293  (void)result;
294  }
295  }
296 
297  for (Operation &op : forOp.getBody()->getOperations()) {
298  if (!dependencies.contains(&op) && !isa<scf::YieldOp>(op))
299  opsWithPipelineStages.emplace_back(&op, depth);
300  }
301  for (Operation &op : forOp.getBody()->getOperations()) {
302  if (dependencies.contains(&op))
303  opsWithPipelineStages.emplace_back(&op, 0);
304  }
305 }
306 
307 /// Hook for the loop pipeliner. Replaces op with a predicated version and
308 /// returns the resulting operation. Returns the original op if the predication
309 /// isn't necessary for the given op. Returns null if predication is needed but
310 /// not supported.
312  Operation *op, Value predicate) {
313  // Some operations may be fine to execute "speculatively" more times than the
314  // original number of iterations, in particular side-effect free operations
315  // and barriers, even if they cannot be predicated.
316  if (isMemoryEffectFree(op) ||
317  isa<gpu::BarrierOp, nvgpu::DeviceAsyncCreateGroupOp,
318  nvgpu::DeviceAsyncWaitOp>(op)) {
319  return op;
320  }
321 
322  // Otherwise, only async copies can currently be predicated.
323  auto asyncCopyOp = dyn_cast<nvgpu::DeviceAsyncCopyOp>(op);
324  if (!asyncCopyOp)
325  return nullptr;
326 
327  // Create srcElement Value based on `predicate`. The next lines generate
328  // the following code:
329  //
330  // srcElement = (pred) ? prevSrcElements : 0;
331  //
332  Location loc = asyncCopyOp->getLoc();
333  Value dstElements = arith::ConstantOp::create(
334  rewriter, loc, asyncCopyOp.getDstElementsAttr());
335  Value originalSrcElement =
336  asyncCopyOp.getSrcElements() ? asyncCopyOp.getSrcElements() : dstElements;
337  Value c0Index = arith::ConstantIndexOp::create(rewriter, loc, 0);
338  auto srcElements = arith::SelectOp::create(rewriter, loc, predicate,
339  originalSrcElement, c0Index);
340  auto asyncCopyZeroFillOp = nvgpu::DeviceAsyncCopyOp::create(
341  rewriter, loc, nvgpu::DeviceAsyncTokenType::get(asyncCopyOp.getContext()),
342  asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(),
343  asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements,
344  UnitAttr());
345  rewriter.replaceOp(asyncCopyOp, asyncCopyZeroFillOp);
346  return asyncCopyZeroFillOp;
347 }
348 
349 /// Applies loop pipelining with the given depth to the given loop so that
350 /// copies into the shared memory are pipelined. Doesn't affect other loops.
351 /// Returns a pair containing the error state and the pipelined op, the latter
352 /// being null in case of any failure. The error state contains a definite error
353 /// if the IR has been modified and a silenceable error otherwise.
354 static std::tuple<DiagnosedSilenceableFailure, scf::ForOp>
355 pipelineForSharedCopies(RewriterBase &rewriter, scf::ForOp forOp, int64_t depth,
356  bool epiloguePeeling) {
358  if (failed(collectStage0PipeliningOps(forOp, stage0Ops))) {
359  return std::make_tuple(
360  emitSilenceableFailure(forOp, "cannot find stage 0 ops for pipelining"),
361  scf::ForOp());
362  }
363  if (stage0Ops.empty()) {
364  return std::make_tuple(
365  emitSilenceableFailure(forOp, "no shared memory copy"), scf::ForOp());
366  }
367 
369  unsigned maxDepth = depth;
370  auto setAnnotation = [&](Operation *op,
372  unsigned iteration) {
373  return setAsyncWaitGroupsInFlight(rewriter, op, part, iteration, maxDepth);
374  };
375  options.getScheduleFn =
376  [&](scf::ForOp schedulingFor,
377  std::vector<std::pair<Operation *, unsigned>> &ops) {
378  if (schedulingFor != forOp)
379  return;
380  return getPipelineStages(forOp, ops, maxDepth, stage0Ops);
381  };
382  options.annotateFn = setAnnotation;
383  if (!epiloguePeeling) {
384  options.peelEpilogue = false;
385  options.predicateFn = replaceOpWithPredicatedOp;
386  }
387 
388  OpBuilder::InsertionGuard guard(rewriter);
389  rewriter.setInsertionPoint(forOp);
390  bool modifiedIR;
391  FailureOr<scf::ForOp> maybePipelined =
392  pipelineForLoop(rewriter, forOp, options, &modifiedIR);
393  if (succeeded(maybePipelined)) {
394  return std::make_tuple(DiagnosedSilenceableFailure::success(),
395  *maybePipelined);
396  }
397  return std::make_tuple(
398  modifiedIR
400  : emitSilenceableFailure(forOp, "pipelining preconditions failed"),
401  scf::ForOp());
402 }
403 
404 DiagnosedSilenceableFailure PipelineSharedMemoryCopiesOp::applyToOne(
405  TransformRewriter &rewriter, scf::ForOp forOp,
406  ApplyToEachResultList &results, TransformState &state) {
407  auto [diag, pipelined] = pipelineForSharedCopies(
408  rewriter, forOp, static_cast<int64_t>(getDepth()), getPeelEpilogue());
409  if (diag.succeeded()) {
410  results.push_back(pipelined);
412  }
413  if (diag.isDefiniteFailure()) {
414  auto diag = emitDefiniteFailure("irreversible pipelining failure");
415  if (!getPeelEpilogue()) {
416  diag.attachNote(forOp->getLoc()) << "couldn't predicate?";
417  diag.attachNote(getLoc()) << "try setting " << getPeelEpilogueAttrName();
418  }
419  return diag;
420  }
421 
422  return std::move(diag);
423 }
424 
425 //===----------------------------------------------------------------------===//
426 // RewriteMatmulAsMmaSyncOp
427 //===----------------------------------------------------------------------===//
428 
429 /// Helper struct to encode a pair of row/column indexings in the form of
430 /// affine expressions.
431 struct RowColIndexing : private std::pair<AffineExpr, AffineExpr> {
433  : std::pair<AffineExpr, AffineExpr>(row, col) {}
434 
435  AffineExpr row() const { return first; };
436  AffineExpr col() const { return second; };
437 
438  void print(llvm::raw_ostream &os) const {
439  os << "- indexing: " << first << ", " << second;
440  }
441 };
442 
443 /// Helper struct to provide a simple mapping from matmul operations to the
444 /// corresponding mma.sync operation. This is constrained to the case where the
445 /// matmul matches the mma.sync operation 1-1.
448  : b(b), loc(loc), laneId(laneId) {}
449 
451  std::function<SmallVector<RowColIndexing>(MLIRContext *)>;
452 
453  /// Create the mma.sync operation corresponding to `linalgOp` along with all
454  /// the supporting load/store and vector operations.
455  FailureOr<Operation *> buildMmaSync(LinalgOp linalgOp);
456 
457 private:
458  struct MmaSyncInfo {
459  std::tuple<IndexCalculator, IndexCalculator, IndexCalculator> indexFns;
460  std::tuple<SmallVector<int64_t>, SmallVector<int64_t>, SmallVector<int64_t>>
461  vectorShapes;
462  SmallVector<int64_t> mmaShape;
463  bool tf32Enabled;
464  };
465 
466  /// Return the specific index calculator for the given `linalgOp` or failure
467  /// if the op is not supported. This is the toplevel switch that should just
468  /// be Tablegen'd in the future.
469  FailureOr<MmaSyncInfo> getIndexCalculators(ArrayRef<int64_t> opShape,
470  TypeRange elementalTypes);
471 
472  //===--------------------------------------------------------------------===//
473  // Instruction-specific row, column indexing expression builders.
474  // These should all be declaratively specified via Tablegen in the future.
475  // The Tablegen specification should be as straightforward as possible to
476  // only model the existing size and type combinations.
477  //===--------------------------------------------------------------------===//
478  //
479  // TODO: Tablegen all this.
480  //===--------------------------------------------------------------------===//
481  // m16n8k4 tf32 case.
482  //===--------------------------------------------------------------------===//
483  /// From the NVIDIA doc:
484  /// groupID = %laneid >> 2
485  /// threadIDInGroup = %laneid % 4
486  /// row = groupID for a0
487  /// groupID + 8 for a1
488  /// col = threadIDInGroup
489  static SmallVector<RowColIndexing> m16n8k4tf32Lhs(MLIRContext *ctx) {
490  auto dim = getAffineDimExpr(0, ctx);
491  AffineExpr groupID = dim.floorDiv(4);
492  AffineExpr threadIDInGroup = dim % 4;
493  return {RowColIndexing{groupID, threadIDInGroup},
494  RowColIndexing{groupID + 8, threadIDInGroup}};
495  }
496 
497  /// From the NVIDIA doc:
498  /// groupID = %laneid >> 2
499  /// threadIDInGroup = %laneid % 4
500  /// row = threadIDInGroup
501  /// col = groupID
502  static SmallVector<RowColIndexing> m16n8k4tf32Rhs(MLIRContext *ctx) {
503  auto dim = getAffineDimExpr(0, ctx);
504  AffineExpr groupID = dim.floorDiv(4);
505  AffineExpr threadIDInGroup = dim % 4;
506  return {RowColIndexing{threadIDInGroup, groupID}};
507  }
508 
509  /// From the NVIDIA doc:
510  /// groupID = %laneid >> 2
511  /// threadIDInGroup = %laneid % 4
512  /// row = groupID for c0 and c1
513  /// groupID + 8 for c2 and c3
514  /// col = (threadIDInGroup * 2) + (i & 0x1) for ci where i = {0,..,3}
515  static SmallVector<RowColIndexing> m16n8k4tf32Res(MLIRContext *ctx) {
516  auto dim = getAffineDimExpr(0, ctx);
517  AffineExpr groupID = dim.floorDiv(4);
518  AffineExpr threadIDInGroup = dim % 4;
519  return {RowColIndexing{groupID, threadIDInGroup * 2 + 0},
520  RowColIndexing{groupID, threadIDInGroup * 2 + 1},
521  RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},
522  RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}};
523  }
524 
525  //===--------------------------------------------------------------------===//
526  // m16n8k16 f16 case.
527  //===--------------------------------------------------------------------===//
528  /// From the NVIDIA doc:
529  /// groupID = %laneid >> 2
530  /// threadIDInGroup = %laneid % 4
531  ///
532  /// row = groupID for ai where 0 <= i < 2 || 4 <= i < 6
533  /// groupID + 8 Otherwise
534  ///
535  /// col = (threadIDInGroup * 2) + (i & 0x1) for ai where i < 4
536  /// (threadIDInGroup * 2) + (i & 0x1) + 8 for ai where i >= 4
537  static SmallVector<RowColIndexing> m16n8k16f16Lhs(MLIRContext *ctx) {
538  auto dim = getAffineDimExpr(0, ctx);
539  AffineExpr groupID = dim.floorDiv(4);
540  AffineExpr threadIDInGroup = dim % 4;
541  // clang-format off
542  return {
543  RowColIndexing{groupID, threadIDInGroup * 2 + 0}, // i == 0
544  RowColIndexing{groupID, threadIDInGroup * 2 + 1}, // i == 1
545  RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0}, // i == 2
546  RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}, // i == 3
547  RowColIndexing{groupID, threadIDInGroup * 2 + 0 + 8}, // i == 4
548  RowColIndexing{groupID, threadIDInGroup * 2 + 1 + 8}, // i == 5
549  RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0 + 8}, // i == 6
550  RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1 + 8} // i == 7
551  };
552  // clang-format on
553  }
554 
555  /// From the NVIDIA doc:
556  /// groupID = %laneid >> 2
557  /// threadIDInGroup = %laneid % 4
558  ///
559  /// row = (threadIDInGroup * 2) + (i & 0x1) for bi where i < 2
560  /// (threadIDInGroup * 2) + (i & 0x1) + 8 for bi where i >= 2
561  ///
562  /// col = groupID
563  static SmallVector<RowColIndexing> m16n8k16f16Rhs(MLIRContext *ctx) {
564  auto dim = getAffineDimExpr(0, ctx);
565  AffineExpr groupID = dim.floorDiv(4);
566  AffineExpr threadIDInGroup = dim % 4;
567  // clang-format off
568  return {
569  RowColIndexing{threadIDInGroup * 2 + 0, groupID}, // i == 0
570  RowColIndexing{threadIDInGroup * 2 + 1, groupID}, // i == 1
571  RowColIndexing{threadIDInGroup * 2 + 0 + 8, groupID}, // i == 2
572  RowColIndexing{threadIDInGroup * 2 + 1 + 8, groupID} // i == 3
573  };
574  // clang-format on
575  }
576 
577  /// From the NVIDIA doc:
578  /// groupID = %laneid >> 2
579  /// threadIDInGroup = %laneid % 4
580  ///
581  /// row = groupID for ci where i < 2
582  /// groupID + 8 for ci where i >= 2
583  ///
584  /// col = (threadIDInGroup * 2) + (i & 0x1) for ci where i = {0,..,3}
585  static SmallVector<RowColIndexing> m16n8k16f16Res(MLIRContext *ctx) {
586  auto dim = getAffineDimExpr(0, ctx);
587  AffineExpr groupID = dim.floorDiv(4);
588  AffineExpr threadIDInGroup = dim % 4;
589  // clang-format off
590  return {
591  RowColIndexing{groupID, threadIDInGroup * 2 + 0}, // i == 0
592  RowColIndexing{groupID, threadIDInGroup * 2 + 1}, // i == 1
593  RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0}, // i == 2
594  RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1} // i == 3
595  };
596  // clang-format on
597  }
598 
599  //===--------------------------------------------------------------------===//
600  /// Helper functions to create customizable load and stores operations. The
601  /// specific shapes of each MMA instruction are passed via the
602  /// IndexCalculator callback.
603  //===--------------------------------------------------------------------===//
604  /// Build a list of memref.load operations indexed at `(row, col)` indices
605  /// that make sense for a particular MMA instruction and specified via the
606  /// IndexCalculator callback.
607  SmallVector<Value> buildMemRefLoads(OpBuilder &b, Location loc,
608  OpFoldResult laneId, Value memref,
609  const IndexCalculator &indexFn);
610 
611  /// Perform a distributed load of a vector operand of `vectorShape` for a
612  /// particular MMA instruction whose `(row, col)` indices are specified via
613  /// the IndexCalculator callback. Each `laneId` loads the subportion of the
614  /// data that makes sense for the particular MMA operation.
615  /// The `vectorShape` matches existing NVGPU dialect op specification but
616  /// could also be flattened in the future if needed for simplification.
617  Value buildMmaSyncMemRefLoadOperand(OpBuilder &b, Location loc,
618  OpFoldResult laneId, Value memref,
619  IndexCalculator indexFn,
621 
622  /// Build a list of memref.store operations indexed at `(row, col)` indices
623  /// that make sense for a particular MMA instruction and specified via the
624  /// IndexCalculator callback.
625  SmallVector<Operation *> buildMemRefStores(OpBuilder &b, Location loc,
626  ValueRange toStore,
627  OpFoldResult laneId, Value memref,
628  const IndexCalculator &indexFn);
629 
630  /// Perform a distributed store of a vector operand of `vectorShape` for a
631  /// particular MMA instruction whose `(row, col)` indices are specified via
632  /// the IndexCalculator callback. Each `laneId` loads the subportion of the
633  /// data that makes sense for the particular MMA operation.
634  /// The `vectorShape` matches existing NVGPU dialect op specification but
635  /// could also be flattened in the future if needed for simplification.
636  SmallVector<Operation *> buildMmaSyncMemRefStoreOperand(
637  OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
638  Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape);
639 
640  OpBuilder &b;
641  Location loc;
642  OpFoldResult laneId;
643 };
644 
645 //===--------------------------------------------------------------------===//
646 /// Helper functions to create customizable load and stores operations. The
647 /// specific shapes of each MMA instruction are passed via the
648 /// IndexCalculator callback.
649 //===--------------------------------------------------------------------===//
650 
651 template <typename ApplyFn, typename ReduceFn>
652 static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn,
653  ReduceFn reduceFn) {
654  VectorType vectorType = cast<VectorType>(vector.getType());
655  auto vectorShape = vectorType.getShape();
656  auto strides = computeStrides(vectorShape);
657  for (int64_t idx = 0, e = vectorShape[0] * strides[0]; idx < e; ++idx) {
658  auto indices = delinearize(idx, strides);
659  reduceFn(applyFn(vector, idx, indices), idx, indices);
660  }
661 }
662 
664 MmaSyncBuilder::buildMemRefLoads(OpBuilder &b, Location loc,
665  OpFoldResult laneId, Value memref,
666  const IndexCalculator &indexFn) {
667  auto aff = [&](AffineExpr e) {
668  return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
669  };
670  SmallVector<Value> res;
671  SmallVector<RowColIndexing> indexings = indexFn(b.getContext());
672  for (auto indexing : indexings) {
673  Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
674  Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
675  auto load = memref::LoadOp::create(b, loc, memref, ValueRange{row, col});
676  res.push_back(load);
677  }
678  return res;
679 }
680 
681 Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
682  OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
683  IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
684  auto loads = buildMemRefLoads(b, loc, laneId, memref, std::move(indexFn));
685 
686  Type elementType = getElementTypeOrSelf(memref.getType());
687  auto vt = VectorType::get(vectorShape, elementType);
688  Value res = vector::BroadcastOp::create(b, loc, vt, loads[0]);
690  res,
691  /*applyFn=*/
692  [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
693  return loads[linearIdx];
694  },
695  /*reduceFn=*/
696  [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
697  res = vector::InsertOp::create(b, loc, v, res, indices);
698  });
699 
700  return res;
701 }
702 
703 SmallVector<Operation *> MmaSyncBuilder::buildMemRefStores(
704  OpBuilder &b, Location loc, ValueRange toStore, OpFoldResult laneId,
705  Value memref, const IndexCalculator &indexFn) {
706  auto aff = [&](AffineExpr e) {
707  return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
708  };
710  for (auto [indexing, val] :
711  llvm::zip_equal(indexFn(b.getContext()), toStore)) {
712  Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
713  Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
714  Operation *store =
715  memref::StoreOp::create(b, loc, val, memref, ValueRange{row, col});
716  res.push_back(store);
717  }
718  return res;
719 }
720 
721 SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemRefStoreOperand(
722  OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
723  Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
724  SmallVector<Value> toStore;
725  toStore.reserve(32);
727  vectorToStore,
728  /*applyFn=*/
729  [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
730  return vector::ExtractOp::create(b, loc, vectorToStore, indices);
731  },
732  /*reduceFn=*/
733  [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
734  toStore.push_back(v);
735  });
736  return buildMemRefStores(b, loc, toStore, laneId, memref, std::move(indexFn));
737 }
738 
739 static std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
742  ArrayRef<int64_t> res) {
743  SmallVector<int64_t> vlhs(lhs);
744  SmallVector<int64_t> vrhs(rhs);
745  SmallVector<int64_t> vres(res);
746  return std::make_tuple(vlhs, vrhs, vres);
747 }
748 
749 FailureOr<MmaSyncBuilder::MmaSyncInfo>
750 MmaSyncBuilder::getIndexCalculators(ArrayRef<int64_t> opShape,
751  TypeRange elementalTypes) {
752  // TODO: Tablegen all this.
753  Type f16 = b.getF16Type();
754  Type f32 = b.getF32Type();
755  if (opShape == ArrayRef<int64_t>{16, 8, 4} &&
756  elementalTypes == TypeRange{f32, f32, f32}) {
757  return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k4tf32Lhs,
758  &MmaSyncBuilder::m16n8k4tf32Rhs,
759  &MmaSyncBuilder::m16n8k4tf32Res),
760  makeVectorShapes({2, 1}, {1, 1}, {2, 2}),
761  SmallVector<int64_t>{opShape},
762  /*tf32Enabled=*/true};
763  }
764  // This is the version with f16 accumulation.
765  // TODO: version with f32 accumulation.
766  if (opShape == ArrayRef<int64_t>{16, 8, 16} &&
767  elementalTypes == TypeRange{f16, f16, f16}) {
768  return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k16f16Lhs,
769  &MmaSyncBuilder::m16n8k16f16Rhs,
770  &MmaSyncBuilder::m16n8k16f16Res),
771  makeVectorShapes({4, 2}, {2, 2}, {2, 2}),
772  SmallVector<int64_t>{opShape},
773  /*tf32Enabled=*/false};
774  }
775  return failure();
776 }
777 
778 FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
779  Value lhsMemRef = linalgOp.getDpsInputOperand(0)->get();
780  Value rhsMemRef = linalgOp.getDpsInputOperand(1)->get();
781  Value resMemRef = linalgOp.getDpsInitOperand(0)->get();
782  assert(cast<MemRefType>(lhsMemRef.getType()).getRank() == 2 &&
783  "expected lhs to be a 2D memref");
784  assert(cast<MemRefType>(rhsMemRef.getType()).getRank() == 2 &&
785  "expected rhs to be a 2D memref");
786  assert(cast<MemRefType>(resMemRef.getType()).getRank() == 2 &&
787  "expected res to be a 2D memref");
788 
789  int64_t m = cast<MemRefType>(lhsMemRef.getType()).getShape()[0];
790  int64_t n = cast<MemRefType>(rhsMemRef.getType()).getShape()[1];
791  int64_t k = cast<MemRefType>(lhsMemRef.getType()).getShape()[1];
792  Type lhsType = getElementTypeOrSelf(lhsMemRef.getType());
793  Type rhsType = getElementTypeOrSelf(rhsMemRef.getType());
794  Type resType = getElementTypeOrSelf(resMemRef.getType());
795 
796  FailureOr<MmaSyncInfo> maybeInfo =
797  getIndexCalculators({m, n, k}, {lhsType, rhsType, resType});
798  if (failed(maybeInfo))
799  return failure();
800 
801  MmaSyncInfo info = *maybeInfo;
802  auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns;
803  auto [lhsShape, rhsShape, resShape] = info.vectorShapes;
804  Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, lhsMemRef,
805  lhsIndexFn, lhsShape);
806  Value rhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, rhsMemRef,
807  rhsIndexFn, rhsShape);
808  Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef,
809  resIndexFn, resShape);
810  res = nvgpu::MmaSyncOp::create(b, loc, lhs, rhs, res, info.mmaShape,
811  info.tf32Enabled);
812  buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn,
813  resShape);
814  return res.getDefiningOp();
815 }
816 
/// Rewrites a single `linalg::MatmulOp` payload op through `MmaSyncBuilder`.
///
/// A silenceable error is emitted when the target is not a `linalg::MatmulOp`,
/// when it has user-defined indexing maps, or when
/// `MmaSyncBuilder::buildMmaSync` fails. Otherwise the original op is erased
/// once the replacement IR has been built. The lane id is taken from
/// `gpu.thread_id x`, i.e. a single warp is assumed (see the TODO below).
817 DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne(
818  transform::TransformRewriter &rewriter, LinalgOp linalgOp,
820  transform::TransformState &state) {
821  bool fail = true;
822  // TODO: more robust detection of matmulOp, with transposes etc.
823  if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) {
824  // Reject matmuls with extended semantics (user-defined indexing maps);
825  // this transform does not handle them.
826  if (linalgOp.hasUserDefinedMaps()) {
827  return emitSilenceableError()
828  << "only matmul ops with non-extended semantics are supported";
829  }
830  Location loc = linalgOp.getLoc();
831  // TODO: more robust computation of laneId, for now assume a single warp.
832  Value laneId = gpu::ThreadIdOp::create(
833  rewriter, loc, rewriter.getIndexType(), gpu::Dimension::x);
834  if (succeeded(MmaSyncBuilder(rewriter, loc, laneId).buildMmaSync(linalgOp)))
835  fail = false;
836  }
837 
     // Reached both for non-matmul targets and for matmuls that buildMmaSync
     // could not handle.
838  if (fail) {
839  DiagnosedSilenceableFailure diag = emitSilenceableError()
840  << "unsupported target op: " << linalgOp;
841  diag.attachNote(linalgOp->getLoc()) << "target op";
842  return diag;
843  }
844 
     // The replacement IR has been built; drop the original op.
845  rewriter.eraseOp(linalgOp);
847 }
848 
849 //===----------------------------------------------------------------------===//
850 // Hopper builders.
851 //===----------------------------------------------------------------------===//
852 
853 /// Helper to create the base Hopper-specific operations that are reused in
854 /// various other places.
857  : rewriter(rewriter), loc(loc) {}
858 
     /// Create an mbarrier group in the shared (workgroup) address space,
     /// initialize it with `numThreads`, and emit a gpu.barrier afterwards.
860  buildAndInitBarrierInSharedMemory(OpFoldResult numThreads);
861 
862  /// Create tma descriptor op to initiate transfer from global to shared
863  /// memory. This must be done before the launch op, on the host.
865  buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
866  gpu::LaunchOp launchOp);
867 
868  /// Build a tma load from global memory to shared memory using `barrier` to
869  /// synchronize. Return the number of bytes that will be transferred.
871  buildTmaAsyncLoad(TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
872  TypedValue<MemRefType> sharedMemref,
     /// Emit an mbarrier arrive-expect-tx on `barrier` whose transaction size is
     /// the sum of `sizes`. `sizes` must not be empty.
875  void buildBarrierArriveTx(TypedValue<nvgpu::MBarrierGroupType> barrier,
876  ArrayRef<OpFoldResult> sizes);
877 
878  /// On threadIdx.x == 0, issue one TMA load per (descriptor, buffer) pair and
879  /// arrive on `barrier` expecting the total byte count; other threads arrive
     /// expecting 0 bytes. Return the TMA load ops created on the thread-0 path.
880  // TODO: In the future, don't hardcode to thread 0 but elect a leader.
881  SmallVector<Operation *> buildPredicateLoadsOnThread0(
883  ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
885 
     /// Emit an mbarrier try-wait-parity on `barrier` with a constant parity of
     /// 0 and a fixed number of ticks before retry.
886  void buildTryWaitParity(TypedValue<nvgpu::MBarrierGroupType> barrier);
887 
890 };
891 
893  ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
894  ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
     // Collects the ops recorded by buildTmaAsyncLoad on the thread-0 path;
     // this is the value returned to the caller.
896  SmallVector<Operation *> loadOps;
     // Predicate: threadIdx.x == 0.
897  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
898  Value tidx = gpu::ThreadIdOp::create(rewriter, loc, gpu::Dimension::x);
899  Value cond = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
900  tidx, zero);
901  // clang-format off
902  scf::IfOp::create(rewriter,
903  /*location=*/loc,
904  /*conditional=*/cond,
905  /*thenBuilder=*/
906  [&](OpBuilder &lb, Location loc) {
     // Thread 0: one TMA load per (descriptor, buffer) pair; `sizes` collects
     // the byte count returned by each load.
     // NOTE(review): `lb` is not referenced in the visible lines; the ops are
     // created through `rewriter` instead.
908  sizes.reserve(globalDescriptors.size());
909  for (auto [desc, shmem] : llvm::zip_equal(
910  globalDescriptors, sharedMemBuffers)) {
911  OpFoldResult sz = buildTmaAsyncLoad(desc, shmem, barrier, loadOps);
912  sizes.push_back(sz);
913  }
914  // TODO: Note that cutlass predeclares the barrier arrive tx before the tma.async.load.
915  // This may or may not have perf implications.
916  buildBarrierArriveTx(barrier, sizes);
917  scf::YieldOp::create(rewriter, loc);
918  },
919  /*elseBuilder=*/
920  [&](OpBuilder &lb, Location loc) {
921  // TODO: is this for no-thread divergence?
922  // Should we just yield the size and hoist?
     // All other threads arrive on the barrier expecting 0 bytes.
923  buildBarrierArriveTx(barrier, getAsIndexOpFoldResult(rewriter.getContext(), 0));
924  scf::YieldOp::create(rewriter, loc);
925  });
926  // clang-format on
927  return loadOps;
928 }
929 
932  b.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
     // Commented-out alternative that would build an i64 integer attribute from
     // the kSharedMemorySpace constant instead of the workgroup address space:
933  // return b.getI64IntegerAttr(static_cast<int64_t>(kSharedMemorySpace));
934 }
935 
// Creates an mbarrier group typed with the shared (workgroup) address space,
// initializes it with `numThreads`, emits a gpu.barrier, and returns the
// barrier value.
938  auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
939  Value barrier = nvgpu::MBarrierCreateOp::create(
940  rewriter, loc,
941  nvgpu::MBarrierGroupType::get(rewriter.getContext(), sharedMemorySpace));
942  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
     // `zero` and the null `Value()` are presumably the barrier index within
     // the group and an absent predicate — TODO confirm against the op builder.
943  nvgpu::MBarrierInitOp::create(
944  rewriter, loc, barrier,
945  getValueOrCreateConstantIndexOp(rewriter, loc, numThreads), zero,
946  Value());
     // NOTE(review): presumably ensures every thread observes the initialized
     // barrier before using it — confirm.
947  gpu::BarrierOp::create(rewriter, loc);
948  return cast<TypedValue<nvgpu::MBarrierGroupType>>(barrier);
949 }
950 
953  gpu::LaunchOp launchOp) {
     // Build the descriptor immediately before `launchOp`; the guard restores
     // the previous insertion point when this function returns.
954  OpBuilder::InsertionGuard guard(rewriter);
955  rewriter.setInsertionPoint(launchOp);
     // Cast `memref` to an unranked memref with the same element type and
     // memory space.
956  Value unrankedMemRef = memref::CastOp::create(
957  rewriter, loc,
958  UnrankedMemRefType::get(memref.getType().getElementType(),
959  memref.getType().getMemorySpace()),
960  memref);
     // Materialize every dimension size of `memref` as an SSA value.
961  SmallVector<OpFoldResult> mixedSizes =
962  memref::getMixedSizes(rewriter, loc, memref);
963  SmallVector<Value> sizes =
964  getValueOrCreateConstantIndexOp(rewriter, loc, mixedSizes);
965 
     // The descriptor type reuses the type of `memref` with its memory space
     // replaced by the shared (workgroup) one; no swizzle, no L2 promotion,
     // zero fill for out-of-bounds, no interleave.
966  auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
967  Value desc = nvgpu::TmaCreateDescriptorOp::create(
968  rewriter, loc,
970  rewriter.getContext(),
971  MemRefType::Builder(memref.getType())
972  .setMemorySpace(sharedMemorySpace),
973  TensorMapSwizzleKind::SWIZZLE_NONE,
974  TensorMapL2PromoKind::L2PROMO_NONE, TensorMapOOBKind::OOB_ZERO,
975  TensorMapInterleaveKind::INTERLEAVE_NONE),
976  unrankedMemRef, sizes);
977  return cast<TypedValue<nvgpu::TensorMapDescriptorType>>(desc);
978 }
979 
982  TypedValue<MemRefType> sharedMemref,
984  SmallVectorImpl<Operation *> &loadOps) {
985  MLIRContext *ctx = rewriter.getContext();
986  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
     // Issue the TMA load into `sharedMemref` and record it in `loadOps`.
     // `{zero, zero}` is presumably the load coordinates and the following
     // `zero` the barrier index — TODO confirm against the op builder.
987  Operation *loadOp = nvgpu::TmaAsyncLoadOp::create(
988  rewriter, loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero},
989  zero, Value(), Value());
990  loadOps.push_back(loadOp);
     // Returned size = product of the dimensions of `sharedMemref` multiplied by
     // the element width in bytes, folded to a constant when possible.
991  auto mixedSizes = memref::getMixedSizes(rewriter, loc, sharedMemref);
992  SmallVector<AffineExpr> symbols(mixedSizes.size());
993  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
     // NOTE(review): `bitwidth / 8` is an integer division, so element types
     // narrower than 8 bits would yield a factor of 0 — confirm whether such
     // types can reach this point.
994  AffineExpr prodExprInBytes =
995  computeProduct(ctx, symbols) *
996  (sharedMemref.getType().getElementTypeBitWidth() / 8);
997  auto res = affine::makeComposedFoldedAffineApply(rewriter, loc,
998  prodExprInBytes, mixedSizes);
999  return res;
1000 }
1001 
1004  ArrayRef<OpFoldResult> mixedSizes) {
      // Emits an mbarrier arrive-expect-tx on `barrier` whose transaction size
      // is the sum of `mixedSizes`.
      // NOTE(review): the assert message below has a typo ("expecte"); it is a
      // runtime string, so it is left untouched by this documentation pass.
1005  assert(!mixedSizes.empty() && "expecte non-empty sizes");
1006  MLIRContext *ctx = rewriter.getContext();
      // Bind one symbol per size and fold their sum over `mixedSizes`.
1007  SmallVector<AffineExpr> symbols(mixedSizes.size());
1008  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
1009  AffineExpr sumExpr = computeSum(ctx, symbols);
1010  OpFoldResult size =
1011  affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, mixedSizes);
1012  Value sizeVal = getValueOrCreateConstantIndexOp(rewriter, loc, size);
      // `zero` and the null `Value()` are presumably the barrier index within
      // the group and an absent predicate — TODO confirm.
1013  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
1014  nvgpu::MBarrierArriveExpectTxOp::create(rewriter, loc, barrier, sizeVal, zero,
1015  Value());
1016 }
1017 
// Emits an mbarrier try-wait-parity on `barrier` using a constant i1 parity of
// 0 and a fixed tick count before retry.
1020  Type i1 = rewriter.getI1Type();
1021  Value parity = LLVM::ConstantOp::create(rewriter, loc, i1, 0);
1022  // 10M is an arbitrary, not too small or too big number to specify the number
1023  // of ticks before retry.
1024  // TODO: hoist this in a default dialect constant.
1025  Value ticksBeforeRetry =
1026  arith::ConstantIndexOp::create(rewriter, loc, 10000000);
      // `zero` is presumably the barrier index within the group — TODO confirm.
1027  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
1028  nvgpu::MBarrierTryWaitParityOp::create(rewriter, loc, barrier, parity,
1029  ticksBeforeRetry, zero);
1030 }
1031 
1032 //===----------------------------------------------------------------------===//
1033 // RewriteCopyAsTmaOp
1034 //===----------------------------------------------------------------------===//
1035 
1036 /// Helper to create the TMA operations corresponding to `linalg::CopyOp`.
     /// Derives from `HopperBuilder` and forwards `rewriter` and `loc` to it.
1037 struct CopyBuilder : public HopperBuilder {
1039  : HopperBuilder(rewriter, loc) {}
1040 
1042 };
1043 
/// Rewrites `copyOps` (each expected to be a `linalg::CopyOp` nested under a
/// `gpu::LaunchOp`) into barrier-synchronized TMA loads and erases the
/// original ops.
///
/// Returns the ops produced by `buildPredicateLoadsOnThread0`, or an empty
/// vector when `copyOps` is empty. The launch op is taken from the first copy
/// op only; every input memref is asserted to be 2-D.
1044 SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) {
1045  MLIRContext *ctx = rewriter.getContext();
1046  if (copyOps.empty())
1047  return SmallVector<Operation *>();
1048 
1049  auto launchOp = copyOps.front()->getParentOfType<gpu::LaunchOp>();
1050  assert(launchOp && "expected launch op");
1051 
1052  // 1. Init a barrier object in shared memory.
1053  OpBuilder::InsertionGuard g(rewriter);
1054  rewriter.setInsertionPoint(copyOps.front());
      // The thread count is the product of the three launch block sizes.
1055  AffineExpr bx, by, bz;
1056  bindSymbols(ctx, bx, by, bz);
1057  AffineExpr prod = computeProduct(ctx, ArrayRef<AffineExpr>{bx, by, bz});
1059  rewriter, loc, prod,
1060  ArrayRef<OpFoldResult>{launchOp.getBlockSizeX(), launchOp.getBlockSizeY(),
1061  launchOp.getBlockSizeZ()});
1062 
1064  buildAndInitBarrierInSharedMemory(numThreads);
1065 
      // Gather one global descriptor and one shared buffer per copy op.
1068  for (Operation *op : copyOps) {
1069  auto copyOp = cast<linalg::CopyOp>(op);
1070  auto inMemRef =
1071  cast<TypedValue<MemRefType>>(copyOp.getDpsInputOperand(0)->get());
1072  assert(inMemRef.getType().getRank() == 2 &&
1073  "expected in to be a 2D memref");
1074 
1075  // 2. Build global memory descriptor.
1077  buildGlobalMemRefDescriptor(inMemRef, launchOp);
1078  globalDescs.push_back(globalDesc);
1079 
1080  // 3. Shared memory and descriptor for the tmp array.
1081  auto shmem =
1082  cast<TypedValue<MemRefType>>(copyOp.getDpsInitOperand(0)->get());
1083  shmems.push_back(shmem);
1084  }
1085 
1086  // 4. Load in from global memory to shared memory using tma.
      // NOTE(review): this second guard re-selects the same insertion point as
      // step 1; it looks redundant but is kept as-is.
1087  OpBuilder::InsertionGuard g2(rewriter);
1088  rewriter.setInsertionPoint(copyOps.front());
1089  SmallVector<Operation *> results =
1090  buildPredicateLoadsOnThread0(globalDescs, shmems, barrier);
1091 
1092  // 5. Spin-loop until data is ready.
1093  buildTryWaitParity(barrier);
1094 
1095  // 6. Erase the ops that have now been rewritten.
1096  for (Operation *op : copyOps)
1097  rewriter.eraseOp(op);
1098 
1099  return results;
1100 }
1101 
1103 transform::RewriteCopyAsTmaOp::apply(transform::TransformRewriter &rewriter,
1104  transform::TransformResults &results,
1105  transform::TransformState &state) {
      // Validates that every payload op is a linalg::CopyOp nested under one
      // common gpu::LaunchOp, then hands all of them to CopyBuilder::rewrite.
1106  auto payloadOps = state.getPayloadOps(getTarget());
1107  gpu::LaunchOp commonLaunchOp;
      // Left uninitialized on purpose: both pointers are only read on the
      // failure path, where the predicate below has assigned them.
1108  Operation *firstOp, *failingOp;
1109  if (llvm::any_of(payloadOps, [&](Operation *op) {
      // The first visited op determines the reference launch op.
1110  if (!commonLaunchOp) {
1111  commonLaunchOp = op->getParentOfType<gpu::LaunchOp>();
1112  firstOp = op;
1113  }
      // Fail when the op has no enclosing launch, a different enclosing launch,
      // or is not a linalg::CopyOp.
1114  auto fail = !op->getParentOfType<gpu::LaunchOp>() ||
1115  commonLaunchOp != op->getParentOfType<gpu::LaunchOp>() ||
1116  !isa<linalg::CopyOp>(op);
1117  if (fail)
1118  failingOp = op;
1119  return fail;
1120  })) {
1122  emitSilenceableError()
1123  << "target ops must be linalg::CopyOp nested under a common "
1124  "gpu.LaunchOp to be rewritten because the tma descriptors need to "
1125  "be created on the host.\nBut got: "
1126  << *firstOp << "\nand " << *failingOp;
1127  return diag;
1128  }
1129 
1130  // TODO: more robust detection of copy, with transposes etc.
      // The ops returned by rewrite() are not used in the visible lines.
1131  CopyBuilder(rewriter, getLoc()).rewrite(llvm::to_vector(payloadOps));
1132 
1134 }
1135 
1136 //===----------------------------------------------------------------------===//
1137 // Transform op registration
1138 //===----------------------------------------------------------------------===//
1139 
1140 namespace {
     /// Transform dialect extension that registers the NVGPU transform ops
     /// listed in the generated NVGPUTransformOps.cpp.inc file.
1141 class NVGPUTransformDialectExtension
1143  NVGPUTransformDialectExtension> {
1144 public:
1145  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NVGPUTransformDialectExtension)
1146 
1147  NVGPUTransformDialectExtension() {
      // Dialects declared as generated — presumably those whose ops the
      // transforms in this file may create; TODO confirm.
1148  declareGeneratedDialect<arith::ArithDialect>();
1149  declareGeneratedDialect<affine::AffineDialect>();
1150  declareGeneratedDialect<nvgpu::NVGPUDialect>();
1151  declareGeneratedDialect<NVVM::NVVMDialect>();
1152  declareGeneratedDialect<vector::VectorDialect>();
      // The op list is expanded from the generated .inc file via GET_OP_LIST.
1153  registerTransformOps<
1154 #define GET_OP_LIST
1155 #include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
1156  >();
1157  }
1158 };
1159 } // namespace
1160 
1161 #define GET_OP_CLASSES
1162 #include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
1163 
// Adds NVGPUTransformDialectExtension to `registry`.
1165  registry.addExtensions<NVGPUTransformDialectExtension>();
1166 }
static constexpr int64_t kSharedMemorySpace
static std::string diag(const llvm::Value &value)
constexpr int kWgmmaSizeM
M size of wgmma.mma_async instruction.
Definition: NVGPUDialect.h:40
static Attribute getSharedAddressSpaceAttribute(OpBuilder &b)
static bool hasDefaultMemorySpace(BaseMemRefType type)
Returns true if the given type has the default memory space.
static LogicalResult collectStage0PipeliningOps(scf::ForOp forOp, llvm::SmallPtrSet< Operation *, 16 > &ops)
Populate ops with the set of operations that belong to the stage 0 of the pipelined version of the gi...
static Operation * replaceOpWithPredicatedOp(RewriterBase &rewriter, Operation *op, Value predicate)
Hook for the loop pipeliner.
static bool isStoreToShared(Operation *op, Value v)
Returns true if the operation is storing the given value into shared memory.
static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn, ReduceFn reduceFn)
Helper functions to create customizable load and stores operations.
static std::tuple< DiagnosedSilenceableFailure, scf::ForOp > pipelineForSharedCopies(RewriterBase &rewriter, scf::ForOp forOp, int64_t depth, bool epiloguePeeling)
Applies loop pipelining with the given depth to the given loop so that copies into the shared memory ...
static bool hasSharedMemorySpace(BaseMemRefType type)
Returns true if the given type has the shared (workgroup) memory space.
static bool isLoadFromGlobalStoredToShared(Operation *op)
Returns true if the operation is a load from the default memory space the result of which is only sto...
static std::tuple< SmallVector< int64_t >, SmallVector< int64_t >, SmallVector< int64_t > > makeVectorShapes(ArrayRef< int64_t > lhs, ArrayRef< int64_t > rhs, ArrayRef< int64_t > res)
static void setAsyncWaitGroupsInFlight(OpBuilder &builder, Operation *op, scf::PipeliningOption::PipelinerPart part, unsigned iteration, unsigned depth)
Hook for the loop pipeliner that sets the "num groups in flight" attribute of async wait operations c...
static void getPipelineStages(scf::ForOp forOp, std::vector< std::pair< Operation *, unsigned >> &opsWithPipelineStages, unsigned depth, llvm::SmallPtrSetImpl< Operation * > &stage0Ops)
Hook for the loop pipeliner that populates ops with the stage information as follows:
static Value getValueLoadedFromGlobal(Operation *op)
Returns the value produced by a load from the default memory space.
static llvm::ManagedStatic< PassManagerOptions > options
static std::optional< VectorShape > vectorShape(Type type)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Definition: SCCP.cpp:67
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition: TypeID.h:331
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExpr floorDiv(uint64_t v) const
Definition: AffineExpr.cpp:959
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:104
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
unsigned getMemorySpaceAsInt() const
[deprecated] Returns the memory space in old raw integer representation.
FloatType getF32Type()
Definition: Builders.cpp:42
FloatType getF16Type()
Definition: Builders.cpp:38
MLIRContext * getContext() const
Definition: Builders.h:55
IntegerType getI1Type()
Definition: Builders.cpp:52
IndexType getIndexType()
Definition: Builders.cpp:50
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:182
Builder & setMemorySpace(Attribute newMemorySpace)
Definition: BuiltinTypes.h:208
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Type conversion class.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF32() const
Definition: Types.cpp:40
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
bool isF16() const
Definition: Types.cpp:38
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
user_range getUsers() const
Definition: Value.h:218
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:197
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
void push_back(Operation *op)
Appends an element to the list.
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
@ kGlobalMemorySpace
Global memory space identifier.
Definition: NVVMDialect.h:42
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1329
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Definition: MemRefOps.cpp:77
MemRefType getMBarrierMemrefType(MLIRContext *context, MBarrierGroupType barrierType)
Return the memref type that can be used to represent an mbarrier object.
void registerTransformDialectExtension(DialectRegistry &registry)
void createAsyncGroups(RewriterBase &rewriter, Operation *op, bool bypassL1)
Convert global->shared vector transfers to async device copies.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
FailureOr< ForOp > pipelineForLoop(RewriterBase &rewriter, ForOp forOp, const PipeliningOption &options, bool *modifiedIR=nullptr)
Generate a pipelined version of the scf.for loop based on the schedule given as option.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
Include the generated interface declarations.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
LogicalResult getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
Fills backwardSlice with the computed backward slice (i.e.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:488
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
Definition: IndexingUtils.h:47
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explicit.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
void populateNVGPUToNVVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition: AffineExpr.h:325
void populateGpuMemorySpaceAttributeConversions(TypeConverter &typeConverter, const MemorySpaceMapping &mapping)
Populates memory space attribute conversion rules for lowering gpu.address_space to integer values.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:111
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:619
int64_t computeSum(ArrayRef< int64_t > basis)
Self-explicit.
void bindSymbolsList(MLIRContext *ctx, MutableArrayRef< AffineExprTy > exprs)
Definition: AffineExpr.h:330
Helper to create the tma operations corresponding to linalg::CopyOp.
SmallVector< Operation * > rewrite(ArrayRef< Operation * > copyOps)
CopyBuilder(RewriterBase &rewriter, Location loc)
Helper to create the base Hopper-specific operations that are reused in various other places.
OpFoldResult buildTmaAsyncLoad(TypedValue< nvgpu::TensorMapDescriptorType > globalDesc, TypedValue< MemRefType > sharedMemref, TypedValue< nvgpu::MBarrierGroupType > barrier, SmallVectorImpl< Operation * > &loadOps)
Build a tma load from global memory to shared memory using barrier to synchronize.
TypedValue< nvgpu::MBarrierGroupType > buildAndInitBarrierInSharedMemory(OpFoldResult numThreads)
void buildTryWaitParity(TypedValue< nvgpu::MBarrierGroupType > barrier)
RewriterBase & rewriter
TypedValue< nvgpu::TensorMapDescriptorType > buildGlobalMemRefDescriptor(TypedValue< MemRefType > memref, gpu::LaunchOp launchOp)
Create tma descriptor op to initiate transfer from global to shared memory.
SmallVector< Operation * > buildPredicateLoadsOnThread0(ArrayRef< TypedValue< nvgpu::TensorMapDescriptorType >> globalDescriptors, ArrayRef< TypedValue< MemRefType >> sharedMemBuffers, TypedValue< nvgpu::MBarrierGroupType > barrier)
If threadIdx.x == 0 does TMA request + wait, else just wait.
void buildBarrierArriveTx(TypedValue< nvgpu::MBarrierGroupType > barrier, ArrayRef< OpFoldResult > sizes)
HopperBuilder(RewriterBase &rewriter, Location loc)
Helper struct to provide a simple mapping from matmul operations to the corresponding mma....
std::function< SmallVector< RowColIndexing >(MLIRContext *)> IndexCalculator
MmaSyncBuilder(OpBuilder &b, Location loc, OpFoldResult laneId)
FailureOr< Operation * > buildMmaSync(LinalgOp linalgOp)
Create the mma.sync operation corresponding to linalgOp along with all the supporting load/store and ...
Helper struct to encode a pair of row/column indexings in the form of affine expressions.
AffineExpr col() const
RowColIndexing(AffineExpr row, AffineExpr col)
void print(llvm::raw_ostream &os) const
AffineExpr row() const
Options to dictate how loops should be pipelined.
Definition: Transforms.h:129