NVGPUTransformOps.cpp
1//===- NVGPUTransformOps.cpp - Implementation of NVGPU transform ops ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"
10
29#include "mlir/IR/AffineExpr.h"
31#include "mlir/IR/Value.h"
32#include "llvm/ADT/ArrayRef.h"
33
34using namespace mlir;
35using namespace mlir::linalg;
36using namespace mlir::nvgpu;
37using namespace mlir::NVVM;
38using namespace mlir::transform;
39
40#define DEBUG_TYPE "nvgpu-transforms"
41
42//===----------------------------------------------------------------------===//
43// Apply...ConversionPatternsOp
44//===----------------------------------------------------------------------===//
45
46void ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
47 TypeConverter &typeConverter, RewritePatternSet &patterns) {
48 auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
49 /// device-side async tokens cannot be materialized in nvvm. We just
50 /// convert them to a dummy i32 type in order to easily drop them during
51 /// conversion.
52 populateGpuMemorySpaceAttributeConversions(
53 llvmTypeConverter, [](gpu::AddressSpace space) -> unsigned {
54 switch (space) {
55 case gpu::AddressSpace::Global:
56 return static_cast<unsigned>(NVVM::NVVMMemorySpace::Global);
57 case gpu::AddressSpace::Workgroup:
58 return static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared);
59 case gpu::AddressSpace::Private:
60 return 0;
61 }
62 llvm_unreachable("unknown address space enum value");
63 return static_cast<unsigned>(NVVM::NVVMMemorySpace::Generic);
64 });
65 llvmTypeConverter.addConversion([&](DeviceAsyncTokenType type) -> Type {
66 return llvmTypeConverter.convertType(
67 IntegerType::get(type.getContext(), 32));
68 });
69 llvmTypeConverter.addConversion([&](MBarrierTokenType type) -> Type {
70 return llvmTypeConverter.convertType(
71 IntegerType::get(type.getContext(), 64));
72 });
73 llvmTypeConverter.addConversion([&](WarpgroupAccumulatorType type) -> Type {
74 Type elemType = type.getFragmented().getElementType();
75 int64_t sizeM = type.getFragmented().getDimSize(0);
76 int64_t sizeN = type.getFragmented().getDimSize(1);
77
78 unsigned numMembers;
79 if (elemType.isF32() || elemType.isInteger(32))
80 numMembers = sizeN / 2;
81 else if (elemType.isF16())
82 numMembers = sizeN / 4;
83 else
84 llvm_unreachable("unsupported type for warpgroup accumulator");
85
86 SmallVector<Type> innerStructBody;
87 for (unsigned i = 0; i < numMembers; i++)
88 innerStructBody.push_back(elemType);
89 auto innerStructType =
90 LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);
91
92 SmallVector<Type> structBody;
93 for (int i = 0; i < sizeM; i += kWgmmaSizeM)
94 structBody.push_back(innerStructType);
95
96 auto convertedType =
97 LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
98 return llvmTypeConverter.convertType(convertedType);
99 });
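  // Worked example (illustrative only, assuming kWgmmaSizeM is 64): a
  // 128x128xf32 fragmented accumulator gives numMembers = 128 / 2 = 64, so
  // the inner struct holds 64 x f32 and the outer struct repeats it
  // 128 / 64 = 2 times, i.e. roughly
  //   !llvm.struct<(struct<(f32, f32, ..., f32)>,   // 64 members
  //                 struct<(f32, f32, ..., f32)>)>  // 64 members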
100 llvmTypeConverter.addConversion([&](MBarrierGroupType type) -> Type {
101 return llvmTypeConverter.convertType(
102 getMBarrierMemrefType(type.getContext(), type));
103 });
104 llvmTypeConverter.addConversion(
105 [&](WarpgroupMatrixDescriptorType type) -> Type {
106 return llvmTypeConverter.convertType(
107 IntegerType::get(type.getContext(), 64));
108 });
109 llvmTypeConverter.addConversion([&](TensorMapDescriptorType type) -> Type {
110 return LLVM::LLVMPointerType::get(type.getContext());
111 });
112 populateNVGPUToNVVMConversionPatterns(llvmTypeConverter, patterns);
113}
114
115LogicalResult ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter(
116 TypeConverterBuilderOpInterface builder) {
117 if (builder.getTypeConverterType() != "LLVMTypeConverter")
118 return emitOpError("expected LLVMTypeConverter");
119 return success();
120}
121
122//===---------------------------------------------------------------------===//
123// CreateAsyncGroupsOp
124//===---------------------------------------------------------------------===//
125
126void CreateAsyncGroupsOp::getEffects(
127 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
128 consumesHandle(getTargetMutable(), effects);
129 producesHandle(getOperation()->getOpResults(), effects);
130 modifiesPayload(effects);
131}
132
133DiagnosedSilenceableFailure
134CreateAsyncGroupsOp::applyToOne(TransformRewriter &rewriter, Operation *target,
135 ApplyToEachResultList &results,
136 TransformState &state) {
137 createAsyncGroups(rewriter, target, getBypassL1());
138 results.push_back(target);
139 return DiagnosedSilenceableFailure::success();
140}
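// Example transform IR driving this op (a sketch; handle names and result
// types depend on how the surrounding transform script produced them):
//   %new = transform.nvgpu.create_async_groups %func {bypass_l1}
//       : (!transform.any_op) -> !transform.any_op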
141
142//===----------------------------------------------------------------------===//
143// PipelineSharedMemoryCopiesOp
144//===----------------------------------------------------------------------===//
145
146/// Returns true if the given type has the default memory space.
147static bool hasDefaultMemorySpace(BaseMemRefType type) {
148 return !type.getMemorySpace() || type.getMemorySpaceAsInt() == 0;
149}
150
151/// Returns true if the given type has the shared (workgroup) memory space.
152static bool hasSharedMemorySpace(BaseMemRefType type) {
153 auto space =
154 dyn_cast_if_present<gpu::AddressSpaceAttr>(type.getMemorySpace());
155 return space &&
156 space.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
157}
158
159/// Returns the value produced by a load from the default memory space. Returns
160/// null if the operation is not such a load.
161static Value getValueLoadedFromGlobal(Operation *op) {
162 // TODO: consider an interface or leveraging the memory effects interface.
163 auto load = dyn_cast<vector::TransferReadOp>(op);
164 if (!load)
165 return nullptr;
166
167 auto loadType = dyn_cast<MemRefType>(load.getBase().getType());
168 if (!loadType || !hasDefaultMemorySpace(loadType))
169 return nullptr;
170 return load;
171}
172
173/// Returns true if the operation is storing the given value into shared memory.
174static bool isStoreToShared(Operation *op, Value v) {
175 // TODO: consider an interface or leveraging the memory effects interface.
176 auto store = dyn_cast<vector::TransferWriteOp>(op);
177 if (!store || store.getVector() != v)
178 return false;
179
180 auto storeType = dyn_cast<MemRefType>(store.getBase().getType());
181 return storeType && hasSharedMemorySpace(storeType);
182}
183
184/// Returns true if the operation is a load from the default memory space the
185/// result of which is only stored into the shared memory space.
186static bool isLoadFromGlobalStoredToShared(Operation *op) {
187 Value loaded = getValueLoadedFromGlobal(op);
188 if (!loaded || !loaded.hasOneUse())
189 return false;
190
191 return isStoreToShared(*loaded.getUsers().begin(), loaded);
192}
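// Illustrative IR shape matched by the two helpers above: a read from the
// default (global) memory space whose only use is a write into workgroup
// memory (a sketch, with abbreviated types and made-up SSA names):
//   %v = vector.transfer_read %global[%i, %j], %pad
//          : memref<?x?xf32>, vector<4xf32>
//   vector.transfer_write %v, %shared[%k, %l]
//          : vector<4xf32>, memref<128x32xf32, #gpu.address_space<workgroup>>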
193
194/// Populate `ops` with the set of operations that belong to the stage 0 of the
195/// pipelined version of the given loop when pipelining copies to shared memory.
196/// Specifically, this collects:
197///
198/// 1. all loads from global memory, both sync and async;
199/// 2. the barriers for async loads.
200///
201/// In particular, barriers are omitted if they do not dominate at least one
202/// async load for which there is not yet a barrier.
203static LogicalResult
204collectStage0PipeliningOps(scf::ForOp forOp,
205 llvm::SmallPtrSet<Operation *, 16> &ops) {
206
207 llvm::SmallPtrSet<Operation *, 16> barriers;
208 for (Operation &op : *forOp.getBody()) {
209 // Bail on nested ops for now.
210 if (op.getNumRegions() > 0)
211 return failure();
212
213 if (isa<gpu::BarrierOp>(op)) {
214 barriers.insert(&op);
215 continue;
216 }
217
218 if (isa<DeviceAsyncCopyOp, DeviceAsyncCreateGroupOp>(op)) {
219 ops.insert(&op);
220 ops.insert(std::make_move_iterator(barriers.begin()),
221 std::make_move_iterator(barriers.end()));
222 assert(barriers.empty() &&
223 "expected to have moved the barriers into another set");
224 continue;
225 }
226
227 if (isLoadFromGlobalStoredToShared(&op)) {
228 ops.insert(&op);
229 continue;
230 }
231 }
232
233 return success();
234}
235
236/// Hook for the loop pipeliner that sets the "num groups in flight" attribute
237/// of async wait operations corresponding to pipelined shared memory copies.
238// TODO: this currently assumes that there are no groups that could be in flight
239// in the existing code.
240static void
241setAsyncWaitGroupsInFlight(OpBuilder &builder, Operation *op,
242 scf::PipeliningOption::PipelinerPart part,
243 unsigned iteration, unsigned depth) {
244 // Based on the order of copies within the loop we need to set the number
245 // of copies in flight, unless it is already set.
246 auto waitOp = dyn_cast<DeviceAsyncWaitOp>(op);
247 if (!waitOp || waitOp.getNumGroups())
248 return;
249
250 int numGroupInFlight = 0;
251 if (part == scf::PipeliningOption::PipelinerPart::Kernel ||
252 part == scf::PipeliningOption::PipelinerPart::Prologue) {
253 numGroupInFlight = depth - 1;
254 } else {
255 // By construction there should be no wait op in the prologue as all the
256 // waits should be in the last stage.
257 assert(part == scf::PipeliningOption::PipelinerPart::Epilogue);
258 // Based on the schedule we pick we know how many groups are in flight for
259 // each iteration of the epilogue.
260 numGroupInFlight = depth - 1 - iteration;
261 }
262 waitOp.setNumGroups(numGroupInFlight);
263}
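// Worked example: with depth = 3, waits in the kernel and prologue keep
// depth - 1 = 2 groups in flight; in the peeled epilogue the count decreases
// with each iteration: 2, 1, 0 for iterations 0, 1, 2.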
264
265/// Hook for the loop pipeliner that populates `ops` with the stage information
266/// as follows:
267///
268/// - operations in `stage0Ops` (typically loads from global memory and
269/// related barriers) are at stage 0;
270/// - operations in the backward slice of any stage0Ops are all at stage 0;
271/// - other operations are at stage `depth`;
272/// - the internal order of the pipelined loop has ops at stage `depth` first,
273/// then those at stage 0, with relative order within each group preserved.
274///
275static void getPipelineStages(
276 scf::ForOp forOp,
277 std::vector<std::pair<Operation *, unsigned>> &opsWithPipelineStages,
278 unsigned depth, llvm::SmallPtrSetImpl<Operation *> &stage0Ops) {
279 SetVector<Operation *> dependencies;
280 BackwardSliceOptions options([&](Operation *visited) {
281 return visited->getBlock() == forOp.getBody();
282 });
283 options.inclusive = true;
284 for (Operation &op : forOp.getBody()->getOperations()) {
285 if (stage0Ops.contains(&op)) {
286 LogicalResult result = getBackwardSlice(&op, &dependencies, options);
287 assert(result.succeeded() && "expected a backward slice");
288 (void)result;
289 }
290 }
291
292 for (Operation &op : forOp.getBody()->getOperations()) {
293 if (!dependencies.contains(&op) && !isa<scf::YieldOp>(op))
294 opsWithPipelineStages.emplace_back(&op, depth);
295 }
296 for (Operation &op : forOp.getBody()->getOperations()) {
297 if (dependencies.contains(&op))
298 opsWithPipelineStages.emplace_back(&op, 0);
299 }
300}
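// Example: in a loop that issues nvgpu.device_async_copy +
// nvgpu.device_async_create_group and then computes on data copied by an
// earlier iteration, the copy/group ops (plus their backward slices) get
// stage 0, everything else (including the async wait) gets stage `depth`,
// and the stage-`depth` ops are listed first in the pipelined loop body.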
301
302/// Hook for the loop pipeliner. Replaces op with a predicated version and
303/// returns the resulting operation. Returns the original op if the predication
304/// isn't necessary for the given op. Returns null if predication is needed but
305/// not supported.
306static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter,
307 Operation *op, Value predicate) {
308 // Some operations may be fine to execute "speculatively" more times than the
309 // original number of iterations, in particular side-effect free operations
310 // and barriers, even if they cannot be predicated.
311 if (isMemoryEffectFree(op) ||
312 isa<gpu::BarrierOp, DeviceAsyncCreateGroupOp, DeviceAsyncWaitOp>(op)) {
313 return op;
314 }
315
316 // Otherwise, only async copies can currently be predicated.
317 auto asyncCopyOp = dyn_cast<DeviceAsyncCopyOp>(op);
318 if (!asyncCopyOp)
319 return nullptr;
320
321 // Create srcElement Value based on `predicate`. The next lines generate
322 // the following code:
323 //
324 // srcElement = (pred) ? prevSrcElements : 0;
325 //
326 Location loc = asyncCopyOp->getLoc();
327 Value dstElements = arith::ConstantOp::create(
328 rewriter, loc, asyncCopyOp.getDstElementsAttr());
329 Value originalSrcElement =
330 asyncCopyOp.getSrcElements() ? asyncCopyOp.getSrcElements() : dstElements;
331 Value c0Index = arith::ConstantIndexOp::create(rewriter, loc, 0);
332 auto srcElements = arith::SelectOp::create(rewriter, loc, predicate,
333 originalSrcElement, c0Index);
334 auto asyncCopyZeroFillOp = DeviceAsyncCopyOp::create(
335 rewriter, loc, DeviceAsyncTokenType::get(asyncCopyOp.getContext()),
336 asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(),
337 asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements,
338 UnitAttr());
339 rewriter.replaceOp(asyncCopyOp, asyncCopyZeroFillOp);
340 return asyncCopyZeroFillOp;
341}
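// Sketch of the rewritten IR for a predicated copy (SSA names and types are
// illustrative, not taken from a real test):
//   %c0 = arith.constant 0 : index
//   %n  = arith.select %pred, %srcElements, %c0 : index
//   %tok = nvgpu.device_async_copy %src[...], %dst[...], <dstElements>, %n
//            : memref<...> to memref<..., #gpu.address_space<workgroup>>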
342
343/// Applies loop pipelining with the given depth to the given loop so that
344/// copies into the shared memory are pipelined. Doesn't affect other loops.
345/// Returns a pair containing the error state and the pipelined op, the latter
346/// being null in case of any failure. The error state contains a definite error
347/// if the IR has been modified and a silenceable error otherwise.
348static std::tuple<DiagnosedSilenceableFailure, scf::ForOp>
349pipelineForSharedCopies(RewriterBase &rewriter, scf::ForOp forOp, int64_t depth,
350 bool epiloguePeeling) {
351 llvm::SmallPtrSet<Operation *, 16> stage0Ops;
352 if (failed(collectStage0PipeliningOps(forOp, stage0Ops))) {
353 return std::make_tuple(
354 emitSilenceableFailure(forOp, "cannot find stage 0 ops for pipelining"),
355 scf::ForOp());
356 }
357 if (stage0Ops.empty()) {
358 return std::make_tuple(
359 emitSilenceableFailure(forOp, "no shared memory copy"), scf::ForOp());
360 }
361
362 scf::PipeliningOption options;
363 unsigned maxDepth = depth;
364 auto setAnnotation = [&](Operation *op,
365 scf::PipeliningOption::PipelinerPart part,
366 unsigned iteration) {
367 return setAsyncWaitGroupsInFlight(rewriter, op, part, iteration, maxDepth);
368 };
369 options.getScheduleFn =
370 [&](scf::ForOp schedulingFor,
371 std::vector<std::pair<Operation *, unsigned>> &ops) {
372 if (schedulingFor != forOp)
373 return;
374 return getPipelineStages(forOp, ops, maxDepth, stage0Ops);
375 };
376 options.annotateFn = setAnnotation;
377 if (!epiloguePeeling) {
378 options.peelEpilogue = false;
379 options.predicateFn = replaceOpWithPredicatedOp;
380 }
381
382 OpBuilder::InsertionGuard guard(rewriter);
383 rewriter.setInsertionPoint(forOp);
384 bool modifiedIR;
385 FailureOr<scf::ForOp> maybePipelined =
386 pipelineForLoop(rewriter, forOp, options, &modifiedIR);
387 if (succeeded(maybePipelined)) {
388 return std::make_tuple(DiagnosedSilenceableFailure::success(),
389 *maybePipelined);
390 }
391 return std::make_tuple(
392 modifiedIR
393 ? emitDefiniteFailure(forOp, "irreversible pipelining failure")
394 : emitSilenceableFailure(forOp, "pipelining preconditions failed"),
395 scf::ForOp());
396}
397
398DiagnosedSilenceableFailure PipelineSharedMemoryCopiesOp::applyToOne(
399 TransformRewriter &rewriter, scf::ForOp forOp,
400 ApplyToEachResultList &results, TransformState &state) {
401 auto [diag, pipelined] = pipelineForSharedCopies(
402 rewriter, forOp, static_cast<int64_t>(getDepth()), getPeelEpilogue());
403 if (diag.succeeded()) {
404 results.push_back(pipelined);
405 return DiagnosedSilenceableFailure::success();
406 }
407 if (diag.isDefiniteFailure()) {
408 auto diag = emitDefiniteFailure("irreversible pipelining failure");
409 if (!getPeelEpilogue()) {
410 diag.attachNote(forOp->getLoc()) << "couldn't predicate?";
411 diag.attachNote(getLoc()) << "try setting " << getPeelEpilogueAttrName();
412 }
413 return diag;
414 }
415
416 return std::move(diag);
417}
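// Example transform IR driving this op (a sketch; attribute spelling and
// result types may differ in an actual transform script):
//   %pipelined = transform.nvgpu.pipeline_shared_memory_copies
//       failures(suppress) %for {depth = 2}
//       : (!transform.any_op) -> !transform.any_op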
418
419//===----------------------------------------------------------------------===//
420// RewriteMatmulAsMmaSyncOp
421//===----------------------------------------------------------------------===//
422
423/// Helper struct to encode a pair of row/column indexings in the form of
424/// affine expressions.
425struct RowColIndexing : private std::pair<AffineExpr, AffineExpr> {
426 RowColIndexing(AffineExpr row, AffineExpr col)
427 : std::pair<AffineExpr, AffineExpr>(row, col) {}
428
429 AffineExpr row() const { return first; };
430 AffineExpr col() const { return second; };
431
432 void print(llvm::raw_ostream &os) const {
433 os << "- indexing: " << first << ", " << second;
434 }
435};
436
437/// Helper struct to provide a simple mapping from matmul operations to the
438/// corresponding mma.sync operation. This is constrained to the case where the
439/// matmul matches the mma.sync operation 1-1.
440struct MmaSyncBuilder {
441 MmaSyncBuilder(OpBuilder &b, Location loc, OpFoldResult laneId)
442 : b(b), loc(loc), laneId(laneId) {}
443
444 using IndexCalculator =
445 std::function<SmallVector<RowColIndexing>(MLIRContext *)>;
446
447 /// Create the mma.sync operation corresponding to `linalgOp` along with all
448 /// the supporting load/store and vector operations.
449 FailureOr<Operation *> buildMmaSync(LinalgOp linalgOp);
450
451private:
452 struct MmaSyncInfo {
453 std::tuple<IndexCalculator, IndexCalculator, IndexCalculator> indexFns;
454 std::tuple<SmallVector<int64_t>, SmallVector<int64_t>, SmallVector<int64_t>>
455 vectorShapes;
456 SmallVector<int64_t> mmaShape;
457 bool tf32Enabled;
458 };
459
460 /// Return the specific index calculator for the given `linalgOp` or failure
461 /// if the op is not supported. This is the toplevel switch that should just
462 /// be Tablegen'd in the future.
463 FailureOr<MmaSyncInfo> getIndexCalculators(ArrayRef<int64_t> opShape,
464 TypeRange elementalTypes);
465
466 //===--------------------------------------------------------------------===//
467 // Instruction-specific row, column indexing expression builders.
468 // These should all be declaratively specified via Tablegen in the future.
469 // The Tablegen specification should be as straightforward as possible to
470 // only model the existing size and type combinations.
471 //===--------------------------------------------------------------------===//
472 //
473 // TODO: Tablegen all this.
474 //===--------------------------------------------------------------------===//
475 // m16n8k4 tf32 case.
476 //===--------------------------------------------------------------------===//
477 /// From the NVIDIA doc:
478 /// groupID = %laneid >> 2
479 /// threadIDInGroup = %laneid % 4
480 /// row = groupID for a0
481 /// groupID + 8 for a1
482 /// col = threadIDInGroup
483 static SmallVector<RowColIndexing> m16n8k4tf32Lhs(MLIRContext *ctx) {
484 auto dim = getAffineDimExpr(0, ctx);
485 AffineExpr groupID = dim.floorDiv(4);
486 AffineExpr threadIDInGroup = dim % 4;
487 return {RowColIndexing{groupID, threadIDInGroup},
488 RowColIndexing{groupID + 8, threadIDInGroup}};
489 }
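  // Worked example: for %laneid = 5, groupID = 1 and threadIDInGroup = 1, so
  // this lane owns a0 at (row, col) = (1, 1) and a1 at (9, 1).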
490
491 /// From the NVIDIA doc:
492 /// groupID = %laneid >> 2
493 /// threadIDInGroup = %laneid % 4
494 /// row = threadIDInGroup
495 /// col = groupID
496 static SmallVector<RowColIndexing> m16n8k4tf32Rhs(MLIRContext *ctx) {
497 auto dim = getAffineDimExpr(0, ctx);
498 AffineExpr groupID = dim.floorDiv(4);
499 AffineExpr threadIDInGroup = dim % 4;
500 return {RowColIndexing{threadIDInGroup, groupID}};
501 }
502
503 /// From the NVIDIA doc:
504 /// groupID = %laneid >> 2
505 /// threadIDInGroup = %laneid % 4
506 /// row = groupID for c0 and c1
507 /// groupID + 8 for c2 and c3
508 /// col = (threadIDInGroup * 2) + (i & 0x1) for ci where i = {0,..,3}
509 static SmallVector<RowColIndexing> m16n8k4tf32Res(MLIRContext *ctx) {
510 auto dim = getAffineDimExpr(0, ctx);
511 AffineExpr groupID = dim.floorDiv(4);
512 AffineExpr threadIDInGroup = dim % 4;
513 return {RowColIndexing{groupID, threadIDInGroup * 2 + 0},
514 RowColIndexing{groupID, threadIDInGroup * 2 + 1},
515 RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},
516 RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}};
517 }
518
519 //===--------------------------------------------------------------------===//
520 // m16n8k16 f16 case.
521 //===--------------------------------------------------------------------===//
522 /// From the NVIDIA doc:
523 /// groupID = %laneid >> 2
524 /// threadIDInGroup = %laneid % 4
525 ///
526 /// row = groupID for ai where 0 <= i < 2 || 4 <= i < 6
527 /// groupID + 8 Otherwise
528 ///
529 /// col = (threadIDInGroup * 2) + (i & 0x1) for ai where i < 4
530 /// (threadIDInGroup * 2) + (i & 0x1) + 8 for ai where i >= 4
531 static SmallVector<RowColIndexing> m16n8k16f16Lhs(MLIRContext *ctx) {
532 auto dim = getAffineDimExpr(0, ctx);
533 AffineExpr groupID = dim.floorDiv(4);
534 AffineExpr threadIDInGroup = dim % 4;
535 // clang-format off
536 return {
537 RowColIndexing{groupID, threadIDInGroup * 2 + 0}, // i == 0
538 RowColIndexing{groupID, threadIDInGroup * 2 + 1}, // i == 1
539 RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0}, // i == 2
540 RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}, // i == 3
541 RowColIndexing{groupID, threadIDInGroup * 2 + 0 + 8}, // i == 4
542 RowColIndexing{groupID, threadIDInGroup * 2 + 1 + 8}, // i == 5
543 RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0 + 8}, // i == 6
544 RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1 + 8} // i == 7
545 };
546 // clang-format on
547 }
548
549 /// From the NVIDIA doc:
550 /// groupID = %laneid >> 2
551 /// threadIDInGroup = %laneid % 4
552 ///
553 /// row = (threadIDInGroup * 2) + (i & 0x1) for bi where i < 2
554 /// (threadIDInGroup * 2) + (i & 0x1) + 8 for bi where i >= 2
555 ///
556 /// col = groupID
557 static SmallVector<RowColIndexing> m16n8k16f16Rhs(MLIRContext *ctx) {
558 auto dim = getAffineDimExpr(0, ctx);
559 AffineExpr groupID = dim.floorDiv(4);
560 AffineExpr threadIDInGroup = dim % 4;
561 // clang-format off
562 return {
563 RowColIndexing{threadIDInGroup * 2 + 0, groupID}, // i == 0
564 RowColIndexing{threadIDInGroup * 2 + 1, groupID}, // i == 1
565 RowColIndexing{threadIDInGroup * 2 + 0 + 8, groupID}, // i == 2
566 RowColIndexing{threadIDInGroup * 2 + 1 + 8, groupID} // i == 3
567 };
568 // clang-format on
569 }
570
571 /// From the NVIDIA doc:
572 /// groupID = %laneid >> 2
573 /// threadIDInGroup = %laneid % 4
574 ///
575 /// row = groupID for ci where i < 2
576 /// groupID + 8 for ci where i >= 2
577 ///
578 /// col = (threadIDInGroup * 2) + (i & 0x1) for ci where i = {0,..,3}
579 static SmallVector<RowColIndexing> m16n8k16f16Res(MLIRContext *ctx) {
580 auto dim = getAffineDimExpr(0, ctx);
581 AffineExpr groupID = dim.floorDiv(4);
582 AffineExpr threadIDInGroup = dim % 4;
583 // clang-format off
584 return {
585 RowColIndexing{groupID, threadIDInGroup * 2 + 0}, // i == 0
586 RowColIndexing{groupID, threadIDInGroup * 2 + 1}, // i == 1
587 RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0}, // i == 2
588 RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1} // i == 3
589 };
590 // clang-format on
591 }
592
593 //===--------------------------------------------------------------------===//
594 /// Helper functions to create customizable load and store operations. The
595 /// specific shapes of each MMA instruction are passed via the
596 /// IndexCalculator callback.
597 //===--------------------------------------------------------------------===//
598 /// Build a list of memref.load operations indexed at `(row, col)` indices
599 /// that make sense for a particular MMA instruction and specified via the
600 /// IndexCalculator callback.
601 SmallVector<Value> buildMemRefLoads(OpBuilder &b, Location loc,
602 OpFoldResult laneId, Value memref,
603 const IndexCalculator &indexFn);
604
605 /// Perform a distributed load of a vector operand of `vectorShape` for a
606 /// particular MMA instruction whose `(row, col)` indices are specified via
607 /// the IndexCalculator callback. Each `laneId` loads the subportion of the
608 /// data that makes sense for the particular MMA operation.
609 /// The `vectorShape` matches existing NVGPU dialect op specification but
610 /// could also be flattened in the future if needed for simplification.
611 Value buildMmaSyncMemRefLoadOperand(OpBuilder &b, Location loc,
612 OpFoldResult laneId, Value memref,
613 IndexCalculator indexFn,
614 ArrayRef<int64_t> vectorShape);
615
616 /// Build a list of memref.store operations indexed at `(row, col)` indices
617 /// that make sense for a particular MMA instruction and specified via the
618 /// IndexCalculator callback.
619 SmallVector<Operation *> buildMemRefStores(OpBuilder &b, Location loc,
620 ValueRange toStore,
621 OpFoldResult laneId, Value memref,
622 const IndexCalculator &indexFn);
623
624 /// Perform a distributed store of a vector operand of `vectorShape` for a
625 /// particular MMA instruction whose `(row, col)` indices are specified via
626 /// the IndexCalculator callback. Each `laneId` stores the subportion of the
627 /// data that makes sense for the particular MMA operation.
628 /// The `vectorShape` matches existing NVGPU dialect op specification but
629 /// could also be flattened in the future if needed for simplification.
630 SmallVector<Operation *> buildMmaSyncMemRefStoreOperand(
631 OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
632 Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape);
633
634 OpBuilder &b;
635 Location loc;
636 OpFoldResult laneId;
637};
638
639//===--------------------------------------------------------------------===//
640/// Helper functions to create customizable load and store operations. The
641/// specific shapes of each MMA instruction are passed via the
642/// IndexCalculator callback.
643//===--------------------------------------------------------------------===//
644
645template <typename ApplyFn, typename ReduceFn>
646static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn,
647 ReduceFn reduceFn) {
648 VectorType vectorType = cast<VectorType>(vector.getType());
649 auto vectorShape = vectorType.getShape();
650 auto strides = computeStrides(vectorShape);
651 for (int64_t idx = 0, e = vectorShape[0] * strides[0]; idx < e; ++idx) {
652 auto indices = delinearize(idx, strides);
653 reduceFn(applyFn(vector, idx, indices), idx, indices);
654 }
655}
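// Example: for a vector<2x2xf16>, strides = {2, 1} and the loop visits
// linear indices 0..3, delinearized to (0,0), (0,1), (1,0), (1,1); applyFn
// produces a value per position and reduceFn folds it back (e.g. through the
// vector.insert used in the load-operand builder below).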
656
657SmallVector<Value>
658MmaSyncBuilder::buildMemRefLoads(OpBuilder &b, Location loc,
659 OpFoldResult laneId, Value memref,
660 const IndexCalculator &indexFn) {
661 auto aff = [&](AffineExpr e) {
662 return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
663 };
664 SmallVector<Value> res;
665 SmallVector<RowColIndexing> indexings = indexFn(b.getContext());
666 for (auto indexing : indexings) {
667 Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
668 Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
669 auto load = memref::LoadOp::create(b, loc, memref, ValueRange{row, col});
670 res.push_back(load);
671 }
672 return res;
673}
674
675Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
676 OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
677 IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
678 auto loads = buildMemRefLoads(b, loc, laneId, memref, std::move(indexFn));
679
680 Type elementType = getElementTypeOrSelf(memref.getType());
681 auto vt = VectorType::get(vectorShape, elementType);
682 Value res = vector::BroadcastOp::create(b, loc, vt, loads[0]);
683 foreachIndividualVectorElement(
684 res,
685 /*applyFn=*/
686 [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
687 return loads[linearIdx];
688 },
689 /*reduceFn=*/
690 [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
691 res = vector::InsertOp::create(b, loc, v, res, indices);
692 });
693
694 return res;
695}
696
697SmallVector<Operation *> MmaSyncBuilder::buildMemRefStores(
698 OpBuilder &b, Location loc, ValueRange toStore, OpFoldResult laneId,
699 Value memref, const IndexCalculator &indexFn) {
700 auto aff = [&](AffineExpr e) {
701 return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
702 };
703 SmallVector<Operation *> res;
704 for (auto [indexing, val] :
705 llvm::zip_equal(indexFn(b.getContext()), toStore)) {
706 Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
707 Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
708 Operation *store =
709 memref::StoreOp::create(b, loc, val, memref, ValueRange{row, col});
710 res.push_back(store);
711 }
712 return res;
713}
714
715SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemRefStoreOperand(
716 OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
717 Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
718 SmallVector<Value> toStore;
719 toStore.reserve(32);
720 foreachIndividualVectorElement(
721 vectorToStore,
722 /*applyFn=*/
723 [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
724 return vector::ExtractOp::create(b, loc, vectorToStore, indices);
725 },
726 /*reduceFn=*/
727 [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
728 toStore.push_back(v);
729 });
730 return buildMemRefStores(b, loc, toStore, laneId, memref, std::move(indexFn));
731}
732
733static std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
734 SmallVector<int64_t>>
735makeVectorShapes(ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs,
736 ArrayRef<int64_t> res) {
737 SmallVector<int64_t> vlhs(lhs);
738 SmallVector<int64_t> vrhs(rhs);
739 SmallVector<int64_t> vres(res);
740 return std::make_tuple(vlhs, vrhs, vres);
741}
742
743FailureOr<MmaSyncBuilder::MmaSyncInfo>
744MmaSyncBuilder::getIndexCalculators(ArrayRef<int64_t> opShape,
745 TypeRange elementalTypes) {
746 // TODO: Tablegen all this.
747 Type f16 = b.getF16Type();
748 Type f32 = b.getF32Type();
749 if (opShape == ArrayRef<int64_t>{16, 8, 4} &&
750 elementalTypes == TypeRange{f32, f32, f32}) {
751 return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k4tf32Lhs,
752 &MmaSyncBuilder::m16n8k4tf32Rhs,
753 &MmaSyncBuilder::m16n8k4tf32Res),
754 makeVectorShapes({2, 1}, {1, 1}, {2, 2}),
755 SmallVector<int64_t>{opShape},
756 /*tf32Enabled=*/true};
757 }
758 // This is the version with f16 accumulation.
759 // TODO: version with f32 accumulation.
760 if (opShape == ArrayRef<int64_t>{16, 8, 16} &&
761 elementalTypes == TypeRange{f16, f16, f16}) {
762 return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k16f16Lhs,
763 &MmaSyncBuilder::m16n8k16f16Rhs,
764 &MmaSyncBuilder::m16n8k16f16Res),
765 makeVectorShapes({4, 2}, {2, 2}, {2, 2}),
766 SmallVector<int64_t>{opShape},
767 /*tf32Enabled=*/false};
768 }
769 return failure();
770}
771
772FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
773 Value lhsMemRef = linalgOp.getDpsInputOperand(0)->get();
774 Value rhsMemRef = linalgOp.getDpsInputOperand(1)->get();
775 Value resMemRef = linalgOp.getDpsInitOperand(0)->get();
776 assert(cast<MemRefType>(lhsMemRef.getType()).getRank() == 2 &&
777 "expected lhs to be a 2D memref");
778 assert(cast<MemRefType>(rhsMemRef.getType()).getRank() == 2 &&
779 "expected rhs to be a 2D memref");
780 assert(cast<MemRefType>(resMemRef.getType()).getRank() == 2 &&
781 "expected res to be a 2D memref");
782
783 int64_t m = cast<MemRefType>(lhsMemRef.getType()).getShape()[0];
784 int64_t n = cast<MemRefType>(rhsMemRef.getType()).getShape()[1];
785 int64_t k = cast<MemRefType>(lhsMemRef.getType()).getShape()[1];
786 Type lhsType = getElementTypeOrSelf(lhsMemRef.getType());
787 Type rhsType = getElementTypeOrSelf(rhsMemRef.getType());
788 Type resType = getElementTypeOrSelf(resMemRef.getType());
789
790 FailureOr<MmaSyncInfo> maybeInfo =
791 getIndexCalculators({m, n, k}, {lhsType, rhsType, resType});
792 if (failed(maybeInfo))
793 return failure();
794
795 MmaSyncInfo info = *maybeInfo;
796 auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns;
797 auto [lhsShape, rhsShape, resShape] = info.vectorShapes;
798 Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, lhsMemRef,
799 lhsIndexFn, lhsShape);
800 Value rhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, rhsMemRef,
801 rhsIndexFn, rhsShape);
802 Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef,
803 resIndexFn, resShape);
804 res =
805 MmaSyncOp::create(b, loc, lhs, rhs, res, info.mmaShape, info.tf32Enabled);
806 buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn,
807 resShape);
808 return res.getDefiningOp();
809}
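// Sketch of the IR produced for a 16x8x4 tf32 matmul (per-lane vector shapes
// come from makeVectorShapes above; SSA names are illustrative):
//   %a = ... : vector<2x1xf32>  // distributed loads of the LHS
//   %b = ... : vector<1x1xf32>  // distributed loads of the RHS
//   %c = ... : vector<2x2xf32>  // distributed loads of the accumulator
//   %d = nvgpu.mma.sync(%a, %b, %c) {mmaShape = [16, 8, 4], tf32Enabled}
//          : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>)
//          -> vector<2x2xf32>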
810
811DiagnosedSilenceableFailure RewriteMatmulAsMmaSyncOp::applyToOne(
812 TransformRewriter &rewriter, LinalgOp linalgOp,
813 ApplyToEachResultList &results, TransformState &state) {
814 bool fail = true;
815 // TODO: more robust detection of matmulOp, with transposes etc.
816 if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) {
817 // Do not let matmuls with extended (user-defined indexing map) semantics
818 // through this transform.
819 if (linalgOp.hasUserDefinedMaps()) {
820 return emitSilenceableError()
821 << "only matmul ops with non-extended semantics are supported";
822 }
823 Location loc = linalgOp.getLoc();
824 // TODO: more robust computation of laneId, for now assume a single warp.
825 Value laneId = gpu::ThreadIdOp::create(
826 rewriter, loc, rewriter.getIndexType(), gpu::Dimension::x);
827 if (succeeded(MmaSyncBuilder(rewriter, loc, laneId).buildMmaSync(linalgOp)))
828 fail = false;
829 }
830
831 if (fail) {
832 DiagnosedSilenceableFailure diag = emitSilenceableError()
833 << "unsupported target op: " << linalgOp;
834 diag.attachNote(linalgOp->getLoc()) << "target op";
835 return diag;
836 }
837
838 rewriter.eraseOp(linalgOp);
839 return DiagnosedSilenceableFailure::success();
840}
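// Example transform IR driving this op (a sketch):
//   transform.nvgpu.rewrite_matmul_as_mma_sync %matmul
//       : (!transform.any_op) -> ()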
841
842//===----------------------------------------------------------------------===//
843// Hopper builders.
844//===----------------------------------------------------------------------===//
845
846/// Helper to create the base Hopper-specific operations that are reused in
847/// various other places.
848struct HopperBuilder {
849 HopperBuilder(RewriterBase &rewriter, Location loc)
850 : rewriter(rewriter), loc(loc) {}
851
852 TypedValue<MBarrierGroupType>
853 buildAndInitBarrierInSharedMemory(OpFoldResult numThreads);
854
855 /// Create tma descriptor op to initiate transfer from global to shared
856 /// memory. This must be done before the launch op, on the host.
857 TypedValue<TensorMapDescriptorType>
858 buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
859 gpu::LaunchOp launchOp);
860
861 /// Build a tma load from global memory to shared memory using `barrier` to
862 /// synchronize. Return the number of bytes that will be transferred.
863 OpFoldResult buildTmaAsyncLoad(TypedValue<TensorMapDescriptorType> globalDesc,
864 TypedValue<MemRefType> sharedMemref,
865 TypedValue<MBarrierGroupType> barrier,
866 SmallVectorImpl<Operation *> &loadOps);
867 void buildBarrierArriveTx(TypedValue<MBarrierGroupType> barrier,
868 ArrayRef<OpFoldResult> sizes);
869
870 /// If threadIdx.x == 0 does TMA request + wait, else just wait.
871 /// Return the operation that performs the transfer on thread0.
872 // TODO: In the future, don't hardcode to thread 0 but elect a leader.
873 SmallVector<Operation *> buildPredicateLoadsOnThread0(
874 ArrayRef<TypedValue<TensorMapDescriptorType>> globalDescriptors,
875 ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
876 TypedValue<MBarrierGroupType> barrier);
877
878 void buildTryWaitParity(TypedValue<MBarrierGroupType> barrier);
879
880 RewriterBase &rewriter;
881 Location loc;
882};
883
884SmallVector<Operation *> HopperBuilder::buildPredicateLoadsOnThread0(
885 ArrayRef<TypedValue<TensorMapDescriptorType>> globalDescriptors,
886 ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
887 TypedValue<MBarrierGroupType> barrier) {
888 SmallVector<Operation *> loadOps;
889 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
890 Value tidx = gpu::ThreadIdOp::create(rewriter, loc, gpu::Dimension::x);
891 Value cond = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
892 tidx, zero);
893 // clang-format off
894 scf::IfOp::create(rewriter,
895 /*location=*/loc,
896 /*conditional=*/cond,
897 /*thenBuilder=*/
898 [&](OpBuilder &lb, Location loc) {
899 SmallVector<OpFoldResult> sizes;
900 sizes.reserve(globalDescriptors.size());
901 for (auto [desc, shmem] : llvm::zip_equal(
902 globalDescriptors, sharedMemBuffers)) {
903 OpFoldResult sz = buildTmaAsyncLoad(desc, shmem, barrier, loadOps);
904 sizes.push_back(sz);
905 }
906 // TODO: Note that cutlass predeclares the barrier arrive tx before the tma.async.load.
907 // This may or may not have perf implications.
908 buildBarrierArriveTx(barrier, sizes);
909 scf::YieldOp::create(rewriter, loc);
910 },
911 /*elseBuilder=*/
912 [&](OpBuilder &lb, Location loc) {
913 // TODO: is this for no-thread divergence?
914 // Should we just yield the size and hoist?
915 buildBarrierArriveTx(barrier, getAsIndexOpFoldResult(rewriter.getContext(), 0));
916 scf::YieldOp::create(rewriter, loc);
917 });
918 // clang-format on
919 return loadOps;
920}
921
922static Attribute getSharedAddressSpaceAttribute(OpBuilder &b) {
923 return gpu::AddressSpaceAttr::get(
924 b.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
925 // return b.getI64IntegerAttr(static_cast<int64_t>(kSharedMemorySpace));
926}
927
928TypedValue<MBarrierGroupType>
929HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) {
930 auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
931 Value barrier = MBarrierCreateOp::create(
932 rewriter, loc,
933 MBarrierGroupType::get(rewriter.getContext(), sharedMemorySpace));
934 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
935 nvgpu::MBarrierInitOp::create(
936 rewriter, loc, barrier,
937 getValueOrCreateConstantIndexOp(rewriter, loc, numThreads), zero,
938 Value());
939 gpu::BarrierOp::create(rewriter, loc);
940 return cast<TypedValue<MBarrierGroupType>>(barrier);
941}
942
943TypedValue<TensorMapDescriptorType>
944HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
945 gpu::LaunchOp launchOp) {
946 OpBuilder::InsertionGuard guard(rewriter);
947 rewriter.setInsertionPoint(launchOp);
948 Value unrankedMemRef = memref::CastOp::create(
949 rewriter, loc,
950 UnrankedMemRefType::get(memref.getType().getElementType(),
951 memref.getType().getMemorySpace()),
952 memref);
953 SmallVector<OpFoldResult> mixedSizes =
954 memref::getMixedSizes(rewriter, loc, memref);
955 SmallVector<Value> sizes =
956 getValueOrCreateConstantIndexOp(rewriter, loc, mixedSizes);
957
958 auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
959 Value desc = TmaCreateDescriptorOp::create(
960 rewriter, loc,
961 TensorMapDescriptorType::get(rewriter.getContext(),
962 MemRefType::Builder(memref.getType())
963 .setMemorySpace(sharedMemorySpace),
964 TensorMapSwizzleKind::SWIZZLE_NONE,
965 TensorMapL2PromoKind::L2PROMO_NONE,
966 TensorMapOOBKind::OOB_ZERO,
967 TensorMapInterleaveKind::INTERLEAVE_NONE),
968 unrankedMemRef, sizes);
969 return cast<TypedValue<TensorMapDescriptorType>>(desc);
970}
971
972OpFoldResult HopperBuilder::buildTmaAsyncLoad(
973 TypedValue<TensorMapDescriptorType> globalDesc,
974 TypedValue<MemRefType> sharedMemref,
975 TypedValue<MBarrierGroupType> barrier,
976 SmallVectorImpl<Operation *> &loadOps) {
977 MLIRContext *ctx = rewriter.getContext();
978 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
979 Operation *loadOp =
980 TmaAsyncLoadOp::create(rewriter, loc, sharedMemref, barrier, globalDesc,
981 ValueRange{zero, zero}, zero, Value(), Value());
982 loadOps.push_back(loadOp);
983 auto mixedSizes = memref::getMixedSizes(rewriter, loc, sharedMemref);
984 SmallVector<AffineExpr> symbols(mixedSizes.size());
985 bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
986 AffineExpr prodExprInBytes =
987 computeProduct(ctx, symbols) *
988 (sharedMemref.getType().getElementTypeBitWidth() / 8);
989 auto res = affine::makeComposedFoldedAffineApply(rewriter, loc,
990 prodExprInBytes, mixedSizes);
991 return res;
992}
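// Worked example: for a 64x64xf16 shared-memory destination the transfer
// size reported to the mbarrier is 64 * 64 * (16 / 8) = 8192 bytes.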
993
994void HopperBuilder::buildBarrierArriveTx(TypedValue<MBarrierGroupType> barrier,
995 ArrayRef<OpFoldResult> mixedSizes) {
996 assert(!mixedSizes.empty() && "expected non-empty sizes");
997 MLIRContext *ctx = rewriter.getContext();
998 SmallVector<AffineExpr> symbols(mixedSizes.size());
999 bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
1000 AffineExpr sumExpr = computeSum(ctx, symbols);
1001 OpFoldResult size =
1002 affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, mixedSizes);
1003 Value sizeVal = getValueOrCreateConstantIndexOp(rewriter, loc, size);
1004 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
1005 nvgpu::MBarrierArriveExpectTxOp::create(rewriter, loc, barrier, sizeVal, zero,
1006 Value());
1007}
1008
1009void HopperBuilder::buildTryWaitParity(TypedValue<MBarrierGroupType> barrier) {
1010 Type i1 = rewriter.getI1Type();
1011 Value parity = LLVM::ConstantOp::create(rewriter, loc, i1, 0);
1012 // 10M is an arbitrary number, neither too small nor too big, for the number
1013 // of ticks to wait before retrying.
1014 // TODO: hoist this into a default dialect constant.
1015 Value ticksBeforeRetry =
1016 arith::ConstantIndexOp::create(rewriter, loc, 10000000);
1017 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
1018 nvgpu::MBarrierTryWaitParityOp::create(rewriter, loc, barrier, parity,
1019 ticksBeforeRetry, zero);
1020}
1021
1022//===----------------------------------------------------------------------===//
1023// RewriteCopyAsTmaOp
1024//===----------------------------------------------------------------------===//
1025
1026/// Helper to create the tma operations corresponding to `linalg::CopyOp`.
1027struct CopyBuilder : public HopperBuilder {
1028 CopyBuilder(RewriterBase &rewriter, Location loc)
1029 : HopperBuilder(rewriter, loc) {}
1030
1031 SmallVector<Operation *> rewrite(ArrayRef<Operation *> copyOps);
1032};
1033
1034SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) {
1035 MLIRContext *ctx = rewriter.getContext();
1036 if (copyOps.empty())
1037 return SmallVector<Operation *>();
1038
1039 auto launchOp = copyOps.front()->getParentOfType<gpu::LaunchOp>();
1040 assert(launchOp && "expected launch op");
1041
1042 // 1. Init a barrier object in shared memory.
1043 OpBuilder::InsertionGuard g(rewriter);
1044 rewriter.setInsertionPoint(copyOps.front());
1045 AffineExpr bx, by, bz;
1046 bindSymbols(ctx, bx, by, bz);
1047 AffineExpr prod = computeProduct(ctx, ArrayRef<AffineExpr>{bx, by, bz});
1048 OpFoldResult numThreads = affine::makeComposedFoldedAffineApply(
1049 rewriter, loc, prod,
1050 ArrayRef<OpFoldResult>{launchOp.getBlockSizeX(), launchOp.getBlockSizeY(),
1051 launchOp.getBlockSizeZ()});
1052
1053 TypedValue<MBarrierGroupType> barrier =
1054 buildAndInitBarrierInSharedMemory(numThreads);
1055
1056 SmallVector<TypedValue<TensorMapDescriptorType>> globalDescs;
1057 SmallVector<TypedValue<MemRefType>> shmems;
1058 for (Operation *op : copyOps) {
1059 auto copyOp = cast<linalg::CopyOp>(op);
1060 auto inMemRef =
1061 cast<TypedValue<MemRefType>>(copyOp.getDpsInputOperand(0)->get());
1062 assert(inMemRef.getType().getRank() == 2 &&
1063 "expected in to be a 2D memref");
1064
1065 // 2. Build global memory descriptor.
1066 TypedValue<TensorMapDescriptorType> globalDesc =
1067 buildGlobalMemRefDescriptor(inMemRef, launchOp);
1068 globalDescs.push_back(globalDesc);
1069
1070 // 3. Shared memory and descriptor for the tmp array.
1071 auto shmem =
1072 cast<TypedValue<MemRefType>>(copyOp.getDpsInitOperand(0)->get());
1073 shmems.push_back(shmem);
1074 }
1075
1076 // 4. Load in from global memory to shared memory using tma.
1078 rewriter.setInsertionPoint(copyOps.front());
1079 SmallVector<Operation *> results =
1080 buildPredicateLoadsOnThread0(globalDescs, shmems, barrier);
1081
1082 // 5. Spin-loop until data is ready.
1083 buildTryWaitParity(barrier);
1084
1085 // 6. Erase the ops that have now been rewritten.
1086 for (Operation *op : copyOps)
1087 rewriter.eraseOp(op);
1088
1089 return results;
1090}
1091
1092DiagnosedSilenceableFailure
1093RewriteCopyAsTmaOp::apply(TransformRewriter &rewriter,
1094 TransformResults &results, TransformState &state) {
1095 auto payloadOps = state.getPayloadOps(getTarget());
1096 gpu::LaunchOp commonLaunchOp;
1097 Operation *firstOp, *failingOp;
1098 if (llvm::any_of(payloadOps, [&](Operation *op) {
1099 if (!commonLaunchOp) {
1100 commonLaunchOp = op->getParentOfType<gpu::LaunchOp>();
1101 firstOp = op;
1102 }
1103 auto fail = !op->getParentOfType<gpu::LaunchOp>() ||
1104 commonLaunchOp != op->getParentOfType<gpu::LaunchOp>() ||
1105 !isa<linalg::CopyOp>(op);
1106 if (fail)
1107 failingOp = op;
1108 return fail;
1109 })) {
1110 DiagnosedSilenceableFailure diag =
1111 emitSilenceableError()
1112 << "target ops must be linalg::CopyOp nested under a common "
1113 "gpu.LaunchOp to be rewritten because the tma descriptors need to "
1114 "be created on the host.\nBut got: "
1115 << *firstOp << "\nand " << *failingOp;
1116 return diag;
1117 }
1118
1119 // TODO: more robust detection of copy, with transposes etc.
1120 CopyBuilder(rewriter, getLoc()).rewrite(llvm::to_vector(payloadOps));
1121
1122 return DiagnosedSilenceableFailure::success();
1123}
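// Example transform IR driving this op (a sketch):
//   transform.nvgpu.rewrite_copy_as_tma %copies : (!transform.any_op) -> ()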
1124
1125//===----------------------------------------------------------------------===//
1126// Transform op registration
1127//===----------------------------------------------------------------------===//
1128
1129namespace {
1130class NVGPUTransformDialectExtension
1131 : public TransformDialectExtension<NVGPUTransformDialectExtension> {
1132public:
1133 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NVGPUTransformDialectExtension)
1134
1135 NVGPUTransformDialectExtension() {
1136 declareGeneratedDialect<arith::ArithDialect>();
1137 declareGeneratedDialect<affine::AffineDialect>();
1138 declareGeneratedDialect<NVGPUDialect>();
1139 declareGeneratedDialect<NVVM::NVVMDialect>();
1140 declareGeneratedDialect<vector::VectorDialect>();
1141 registerTransformOps<
1142#define GET_OP_LIST
1143#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
1144 >();
1145 }
1146};
1147} // namespace
1148
1149#define GET_OP_CLASSES
1150#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
1151
1152void mlir::nvgpu::registerTransformDialectExtension(DialectRegistry &registry) {
1153 registry.addExtensions<NVGPUTransformDialectExtension>();
1154}