//===- NVGPUTransformOps.cpp - Implementation of NVGPU transform ops ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/ArrayRef.h"
using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::nvgpu;
using namespace mlir::NVVM;
using namespace mlir::transform;

#define DEBUG_TYPE "nvgpu-transforms"

//===----------------------------------------------------------------------===//
// Apply...ConversionPatternsOp
//===----------------------------------------------------------------------===//

void ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  /// Device-side async tokens cannot be materialized in nvvm. We just
  /// convert them to a dummy i32 type in order to easily drop them during
  /// conversion.
  populateCommonGPUTypeAndAttributeConversions(llvmTypeConverter);
  llvmTypeConverter.addConversion([&](DeviceAsyncTokenType type) -> Type {
    return llvmTypeConverter.convertType(
        IntegerType::get(type.getContext(), 32));
  });
  llvmTypeConverter.addConversion([&](MBarrierTokenType type) -> Type {
    return llvmTypeConverter.convertType(
        IntegerType::get(type.getContext(), 64));
  });
  llvmTypeConverter.addConversion([&](WarpgroupAccumulatorType type) -> Type {
    Type elemType = type.getFragmented().getElementType();
    int64_t sizeM = type.getFragmented().getDimSize(0);
    int64_t sizeN = type.getFragmented().getDimSize(1);

    unsigned numMembers;
    if (elemType.isF32() || elemType.isInteger(32))
      numMembers = sizeN / 2;
    else if (elemType.isF16())
      numMembers = sizeN / 4;
    else
      llvm_unreachable("unsupported type for warpgroup accumulator");

    SmallVector<Type> innerStructBody;
    for (unsigned i = 0; i < numMembers; i++)
      innerStructBody.push_back(elemType);
    auto innerStructType =
        LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);

    SmallVector<Type> structBody;
    for (int i = 0; i < sizeM; i += kWgmmaSizeM)
      structBody.push_back(innerStructType);

    auto convertedType =
        LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
    return llvmTypeConverter.convertType(convertedType);
  });
  llvmTypeConverter.addConversion([&](MBarrierGroupType type) -> Type {
    return llvmTypeConverter.convertType(
        getMBarrierMemrefType(type.getContext(), type));
  });
  llvmTypeConverter.addConversion(
      [&](WarpgroupMatrixDescriptorType type) -> Type {
        return llvmTypeConverter.convertType(
            IntegerType::get(type.getContext(), 64));
      });
  llvmTypeConverter.addConversion([&](TensorMapDescriptorType type) -> Type {
    return LLVM::LLVMPointerType::get(type.getContext());
  });
  populateNVGPUToNVVMConversionPatterns(llvmTypeConverter, patterns);
}
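
// Worked example (illustrative only, not used by the conversion above),
// assuming kWgmmaSizeM == 64 as defined by the NVGPU dialect: a
// WarpgroupAccumulatorType wrapping a 128x128 f32 fragment converts to a
// struct of 128/64 == 2 inner structs, each holding 128/2 == 64 f32 members.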

LogicalResult ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter(
    TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}

//===---------------------------------------------------------------------===//
// CreateAsyncGroupsOp
//===---------------------------------------------------------------------===//

void CreateAsyncGroupsOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  consumesHandle(getTargetMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
  modifiesPayload(effects);
}

DiagnosedSilenceableFailure
CreateAsyncGroupsOp::applyToOne(TransformRewriter &rewriter, Operation *target,
                                ApplyToEachResultList &results,
                                TransformState &state) {
  createAsyncGroups(rewriter, target, getBypassL1());
  results.push_back(target);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// PipelineSharedMemoryCopiesOp
//===----------------------------------------------------------------------===//

/// Returns true if the given type has the default memory space.
static bool hasDefaultMemorySpace(BaseMemRefType type) {
  return !type.getMemorySpace() || type.getMemorySpaceAsInt() == 0;
}

/// Returns true if the given type has the shared (workgroup) memory space.
static bool hasSharedMemorySpace(BaseMemRefType type) {
  auto space =
      dyn_cast_if_present<gpu::AddressSpaceAttr>(type.getMemorySpace());
  return space &&
         space.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
}

/// Returns the value produced by a load from the default memory space. Returns
/// null if the operation is not such a load.
static Value getValueLoadedFromGlobal(Operation *op) {
  // TODO: consider an interface or leveraging the memory effects interface.
  auto load = dyn_cast<vector::TransferReadOp>(op);
  if (!load)
    return nullptr;

  auto loadType = dyn_cast<MemRefType>(load.getBase().getType());
  if (!loadType || !hasDefaultMemorySpace(loadType))
    return nullptr;
  return load;
}

/// Returns true if the operation is storing the given value into shared
/// memory.
static bool isStoreToShared(Operation *op, Value v) {
  // TODO: consider an interface or leveraging the memory effects interface.
  auto store = dyn_cast<vector::TransferWriteOp>(op);
  if (!store || store.getVector() != v)
    return false;

  auto storeType = dyn_cast<MemRefType>(store.getBase().getType());
  return storeType && hasSharedMemorySpace(storeType);
}

/// Returns true if the operation is a load from the default memory space the
/// result of which is only stored into the shared memory space.
static bool isLoadFromGlobalStoredToShared(Operation *op) {
  Value loaded = getValueLoadedFromGlobal(op);
  if (!loaded || !loaded.hasOneUse())
    return false;

  return isStoreToShared(*loaded.getUsers().begin(), loaded);
}

/// Populate `ops` with the set of operations that belong to the stage 0 of the
/// pipelined version of the given loop when pipelining copies to shared
/// memory. Specifically, this collects:
///
///   1. all loads from global memory, both sync and async;
///   2. the barriers for async loads.
///
/// In particular, barriers are omitted if they do not dominate at least one
/// async load for which there is not yet a barrier.
static LogicalResult
collectStage0PipeliningOps(scf::ForOp forOp,
                           llvm::SmallPtrSet<Operation *, 16> &ops) {

  llvm::SmallPtrSet<Operation *, 16> barriers;
  for (Operation &op : *forOp.getBody()) {
    // Bail on nested ops for now.
    if (op.getNumRegions() > 0)
      return failure();

    if (isa<gpu::BarrierOp>(op)) {
      barriers.insert(&op);
      continue;
    }

    if (isa<DeviceAsyncCopyOp, DeviceAsyncCreateGroupOp>(op)) {
      ops.insert(&op);
      ops.insert(std::make_move_iterator(barriers.begin()),
                 std::make_move_iterator(barriers.end()));
      assert(barriers.empty() &&
             "expected to have moved the barriers into another set");
      continue;
    }

    if (isLoadFromGlobalStoredToShared(&op)) {
      ops.insert(&op);
      continue;
    }
  }

  return success();
}

/// Hook for the loop pipeliner that sets the "num groups in flight" attribute
/// of async wait operations corresponding to pipelined shared memory copies.
// TODO: this currently assumes that there are no groups that could be in
// flight in the existing code.
static void
setAsyncWaitGroupsInFlight(OpBuilder &builder, Operation *op,
                           scf::PipeliningOption::PipelinerPart part,
                           unsigned iteration, unsigned depth) {
  // Based on the order of copies within the loop we need to set the number
  // of copies in flight, unless it is already set.
  auto waitOp = dyn_cast<DeviceAsyncWaitOp>(op);
  if (!waitOp || waitOp.getNumGroups())
    return;

  int numGroupInFlight = 0;
  if (part == scf::PipeliningOption::PipelinerPart::Kernel ||
      part == scf::PipeliningOption::PipelinerPart::Prologue) {
    numGroupInFlight = depth - 1;
  } else {
    // By construction there should be no wait op in the prologue as all the
    // waits should be in the last stage.
    assert(part == scf::PipeliningOption::PipelinerPart::Epilogue);
    // Based on the schedule we pick we know how many groups are in flight for
    // each iteration of the epilogue.
    numGroupInFlight = depth - 1 - iteration;
  }
  waitOp.setNumGroups(numGroupInFlight);
}
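
// For example, with depth == 3 the wait in the pipelined kernel keeps
// depth - 1 == 2 groups in flight, while the peeled epilogue waits keep
// 2, 1 and 0 groups in flight for epilogue iterations 0, 1 and 2.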

/// Hook for the loop pipeliner that populates `ops` with the stage information
/// as follows:
///
///   - operations in `stage0Ops` (typically loads from global memory and
///     related barriers) are at stage 0;
///   - operations in the backward slice of any stage0Ops are all at stage 0;
///   - other operations are at stage `depth`;
///   - the internal order of the pipelined loop has ops at stage `depth` first,
///     then those at stage 0, with relative order within each group preserved.
static void getPipelineStages(
    scf::ForOp forOp,
    std::vector<std::pair<Operation *, unsigned>> &opsWithPipelineStages,
    unsigned depth, llvm::SmallPtrSetImpl<Operation *> &stage0Ops) {
  SetVector<Operation *> dependencies;
  BackwardSliceOptions options([&](Operation *visited) {
    return visited->getBlock() == forOp.getBody();
  });
  options.inclusive = true;
  for (Operation &op : forOp.getBody()->getOperations()) {
    if (stage0Ops.contains(&op)) {
      LogicalResult result = getBackwardSlice(&op, &dependencies, options);
      assert(result.succeeded() && "expected a backward slice");
      (void)result;
    }
  }

  for (Operation &op : forOp.getBody()->getOperations()) {
    if (!dependencies.contains(&op) && !isa<scf::YieldOp>(op))
      opsWithPipelineStages.emplace_back(&op, depth);
  }
  for (Operation &op : forOp.getBody()->getOperations()) {
    if (dependencies.contains(&op))
      opsWithPipelineStages.emplace_back(&op, 0);
  }
}
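
// For instance, with depth == 1 the global-to-shared copies (and the index
// computations they depend on) stay at stage 0 while the rest of the loop
// body moves to stage 1, so the pipelined kernel issues the copies for
// iteration i + 1 while computing on the data of iteration i.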

/// Hook for the loop pipeliner. Replaces op with a predicated version and
/// returns the resulting operation. Returns the original op if the predication
/// isn't necessary for the given op. Returns null if predication is needed but
/// not supported.
static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter,
                                            Operation *op, Value predicate) {
  // Some operations may be fine to execute "speculatively" more times than the
  // original number of iterations, in particular side-effect free operations
  // and barriers, even if they cannot be predicated.
  if (isMemoryEffectFree(op) ||
      isa<gpu::BarrierOp, DeviceAsyncCreateGroupOp, DeviceAsyncWaitOp>(op)) {
    return op;
  }

  // Otherwise, only async copies can currently be predicated.
  auto asyncCopyOp = dyn_cast<DeviceAsyncCopyOp>(op);
  if (!asyncCopyOp)
    return nullptr;

  // Create srcElement Value based on `predicate`. The next lines generate
  // the following code:
  //
  //   srcElement = (pred) ? prevSrcElements : 0;
  //
  Location loc = asyncCopyOp->getLoc();
  Value dstElements = arith::ConstantOp::create(
      rewriter, loc, asyncCopyOp.getDstElementsAttr());
  Value originalSrcElement =
      asyncCopyOp.getSrcElements() ? asyncCopyOp.getSrcElements() : dstElements;
  Value c0Index = arith::ConstantIndexOp::create(rewriter, loc, 0);
  auto srcElements = arith::SelectOp::create(rewriter, loc, predicate,
                                             originalSrcElement, c0Index);
  auto asyncCopyZeroFillOp = DeviceAsyncCopyOp::create(
      rewriter, loc, DeviceAsyncTokenType::get(asyncCopyOp.getContext()),
      asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(),
      asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements,
      UnitAttr());
  rewriter.replaceOp(asyncCopyOp, asyncCopyZeroFillOp);
  return asyncCopyZeroFillOp;
}

/// Applies loop pipelining with the given depth to the given loop so that
/// copies into the shared memory are pipelined. Doesn't affect other loops.
/// Returns a pair containing the error state and the pipelined op, the latter
/// being null in case of any failure. The error state contains a definite
/// error if the IR has been modified and a silenceable error otherwise.
static std::tuple<DiagnosedSilenceableFailure, scf::ForOp>
pipelineForSharedCopies(RewriterBase &rewriter, scf::ForOp forOp, int64_t depth,
                        bool epiloguePeeling) {
  llvm::SmallPtrSet<Operation *, 16> stage0Ops;
  if (failed(collectStage0PipeliningOps(forOp, stage0Ops))) {
    return std::make_tuple(
        emitSilenceableFailure(forOp, "cannot find stage 0 ops for pipelining"),
        scf::ForOp());
  }
  if (stage0Ops.empty()) {
    return std::make_tuple(
        emitSilenceableFailure(forOp, "no shared memory copy"), scf::ForOp());
  }

  scf::PipeliningOption options;
  unsigned maxDepth = depth;
  auto setAnnotation = [&](Operation *op,
                           scf::PipeliningOption::PipelinerPart part,
                           unsigned iteration) {
    return setAsyncWaitGroupsInFlight(rewriter, op, part, iteration, maxDepth);
  };
  options.getScheduleFn =
      [&](scf::ForOp schedulingFor,
          std::vector<std::pair<Operation *, unsigned>> &ops) {
        if (schedulingFor != forOp)
          return;
        return getPipelineStages(forOp, ops, maxDepth, stage0Ops);
      };
  options.annotateFn = setAnnotation;
  if (!epiloguePeeling) {
    options.peelEpilogue = false;
    options.predicateFn = replaceOpWithPredicatedOp;
  }

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(forOp);
  bool modifiedIR;
  FailureOr<scf::ForOp> maybePipelined =
      pipelineForLoop(rewriter, forOp, options, &modifiedIR);
  if (succeeded(maybePipelined)) {
    return std::make_tuple(DiagnosedSilenceableFailure::success(),
                           *maybePipelined);
  }
  return std::make_tuple(
      modifiedIR
          ? DiagnosedSilenceableFailure::definiteFailure()
          : emitSilenceableFailure(forOp, "pipelining preconditions failed"),
      scf::ForOp());
}

DiagnosedSilenceableFailure PipelineSharedMemoryCopiesOp::applyToOne(
    TransformRewriter &rewriter, scf::ForOp forOp,
    ApplyToEachResultList &results, TransformState &state) {
  auto [diag, pipelined] = pipelineForSharedCopies(
      rewriter, forOp, static_cast<int64_t>(getDepth()), getPeelEpilogue());
  if (diag.succeeded()) {
    results.push_back(pipelined);
    return DiagnosedSilenceableFailure::success();
  }
  if (diag.isDefiniteFailure()) {
    auto diag = emitDefiniteFailure("irreversible pipelining failure");
    if (!getPeelEpilogue()) {
      diag.attachNote(forOp->getLoc()) << "couldn't predicate?";
      diag.attachNote(getLoc()) << "try setting " << getPeelEpilogueAttrName();
    }
    return diag;
  }

  return std::move(diag);
}

//===----------------------------------------------------------------------===//
// RewriteMatmulAsMmaSyncOp
//===----------------------------------------------------------------------===//

/// Helper struct to encode a pair of row/column indexings in the form of
/// affine expressions.
struct RowColIndexing : private std::pair<AffineExpr, AffineExpr> {
  RowColIndexing(AffineExpr row, AffineExpr col)
      : std::pair<AffineExpr, AffineExpr>(row, col) {}

  AffineExpr row() const { return first; }
  AffineExpr col() const { return second; }

  void print(llvm::raw_ostream &os) const {
    os << "- indexing: " << first << ", " << second;
  }
};

/// Helper struct to provide a simple mapping from matmul operations to the
/// corresponding mma.sync operation. This is constrained to the case where the
/// matmul matches the mma.sync operation 1-1.
struct MmaSyncBuilder {
  MmaSyncBuilder(OpBuilder &b, Location loc, OpFoldResult laneId)
      : b(b), loc(loc), laneId(laneId) {}

  using IndexCalculator =
      std::function<SmallVector<RowColIndexing>(MLIRContext *)>;

  /// Create the mma.sync operation corresponding to `linalgOp` along with all
  /// the supporting load/store and vector operations.
  FailureOr<Operation *> buildMmaSync(LinalgOp linalgOp);

private:
  struct MmaSyncInfo {
    std::tuple<IndexCalculator, IndexCalculator, IndexCalculator> indexFns;
    std::tuple<SmallVector<int64_t>, SmallVector<int64_t>, SmallVector<int64_t>>
        vectorShapes;
    SmallVector<int64_t> mmaShape;
    bool tf32Enabled;
  };

  /// Return the specific index calculator for the given `linalgOp` or failure
  /// if the op is not supported. This is the toplevel switch that should just
  /// be Tablegen'd in the future.
  FailureOr<MmaSyncInfo> getIndexCalculators(ArrayRef<int64_t> opShape,
                                             TypeRange elementalTypes);

  //===--------------------------------------------------------------------===//
  // Instruction-specific row, column indexing expression builders.
  // These should all be declaratively specified via Tablegen in the future.
  // The Tablegen specification should be as straightforward as possible to
  // only model the existing size and type combinations.
  //===--------------------------------------------------------------------===//
  //
  // TODO: Tablegen all this.
  //===--------------------------------------------------------------------===//
  // m16n8k4 tf32 case.
  //===--------------------------------------------------------------------===//
  /// From the NVIDIA doc:
  ///   groupID         = %laneid >> 2
  ///   threadIDInGroup = %laneid % 4
  ///   row = groupID      for a0
  ///         groupID + 8  for a1
  ///   col = threadIDInGroup
  static SmallVector<RowColIndexing> m16n8k4tf32Lhs(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    return {RowColIndexing{groupID, threadIDInGroup},
            RowColIndexing{groupID + 8, threadIDInGroup}};
  }
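
  // Worked example (illustrative): for %laneid == 5, groupID == 1 and
  // threadIDInGroup == 1, so that lane provides a0 at (row, col) == (1, 1)
  // and a1 at (9, 1) of the 16x4 tf32 LHS fragment.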

  /// From the NVIDIA doc:
  ///   groupID         = %laneid >> 2
  ///   threadIDInGroup = %laneid % 4
  ///   row = threadIDInGroup
  ///   col = groupID
  static SmallVector<RowColIndexing> m16n8k4tf32Rhs(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    return {RowColIndexing{threadIDInGroup, groupID}};
  }

  /// From the NVIDIA doc:
  ///   groupID         = %laneid >> 2
  ///   threadIDInGroup = %laneid % 4
  ///   row = groupID      for c0 and c1
  ///         groupID + 8  for c2 and c3
  ///   col = (threadIDInGroup * 2) + (i & 0x1)  for ci where i = {0,..,3}
  static SmallVector<RowColIndexing> m16n8k4tf32Res(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    return {RowColIndexing{groupID, threadIDInGroup * 2 + 0},
            RowColIndexing{groupID, threadIDInGroup * 2 + 1},
            RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},
            RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}};
  }

  //===--------------------------------------------------------------------===//
  // m16n8k16 f16 case.
  //===--------------------------------------------------------------------===//
  /// From the NVIDIA doc:
  ///   groupID         = %laneid >> 2
  ///   threadIDInGroup = %laneid % 4
  ///
  ///   row = groupID      for ai where 0 <= i < 2 || 4 <= i < 6
  ///         groupID + 8  otherwise
  ///
  ///   col = (threadIDInGroup * 2) + (i & 0x1)      for ai where i <  4
  ///         (threadIDInGroup * 2) + (i & 0x1) + 8  for ai where i >= 4
  static SmallVector<RowColIndexing> m16n8k16f16Lhs(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    // clang-format off
    return {
      RowColIndexing{groupID, threadIDInGroup * 2 + 0},         // i == 0
      RowColIndexing{groupID, threadIDInGroup * 2 + 1},         // i == 1
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},     // i == 2
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1},     // i == 3
      RowColIndexing{groupID, threadIDInGroup * 2 + 0 + 8},     // i == 4
      RowColIndexing{groupID, threadIDInGroup * 2 + 1 + 8},     // i == 5
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0 + 8}, // i == 6
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1 + 8}  // i == 7
    };
    // clang-format on
  }

  /// From the NVIDIA doc:
  ///   groupID         = %laneid >> 2
  ///   threadIDInGroup = %laneid % 4
  ///
  ///   row = (threadIDInGroup * 2) + (i & 0x1)      for bi where i <  2
  ///         (threadIDInGroup * 2) + (i & 0x1) + 8  for bi where i >= 2
  ///
  ///   col = groupID
  static SmallVector<RowColIndexing> m16n8k16f16Rhs(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    // clang-format off
    return {
      RowColIndexing{threadIDInGroup * 2 + 0, groupID},     // i == 0
      RowColIndexing{threadIDInGroup * 2 + 1, groupID},     // i == 1
      RowColIndexing{threadIDInGroup * 2 + 0 + 8, groupID}, // i == 2
      RowColIndexing{threadIDInGroup * 2 + 1 + 8, groupID}  // i == 3
    };
    // clang-format on
  }

  /// From the NVIDIA doc:
  ///   groupID         = %laneid >> 2
  ///   threadIDInGroup = %laneid % 4
  ///
  ///   row = groupID      for ci where i <  2
  ///         groupID + 8  for ci where i >= 2
  ///
  ///   col = (threadIDInGroup * 2) + (i & 0x1)  for ci where i = {0,..,3}
  static SmallVector<RowColIndexing> m16n8k16f16Res(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    // clang-format off
    return {
      RowColIndexing{groupID, threadIDInGroup * 2 + 0},     // i == 0
      RowColIndexing{groupID, threadIDInGroup * 2 + 1},     // i == 1
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0}, // i == 2
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}  // i == 3
    };
    // clang-format on
  }

  //===--------------------------------------------------------------------===//
  /// Helper functions to create customizable load and store operations. The
  /// specific shapes of each MMA instruction are passed via the
  /// IndexCalculator callback.
  //===--------------------------------------------------------------------===//
  /// Build a list of memref.load operations indexed at `(row, col)` indices
  /// that make sense for a particular MMA instruction and specified via the
  /// IndexCalculator callback.
  SmallVector<Value> buildMemRefLoads(OpBuilder &b, Location loc,
                                      OpFoldResult laneId, Value memref,
                                      const IndexCalculator &indexFn);

  /// Perform a distributed load of a vector operand of `vectorShape` for a
  /// particular MMA instruction whose `(row, col)` indices are specified via
  /// the IndexCalculator callback. Each `laneId` loads the subportion of the
  /// data that makes sense for the particular MMA operation.
  /// The `vectorShape` matches existing NVGPU dialect op specification but
  /// could also be flattened in the future if needed for simplification.
  Value buildMmaSyncMemRefLoadOperand(OpBuilder &b, Location loc,
                                      OpFoldResult laneId, Value memref,
                                      IndexCalculator indexFn,
                                      ArrayRef<int64_t> vectorShape);

  /// Build a list of memref.store operations indexed at `(row, col)` indices
  /// that make sense for a particular MMA instruction and specified via the
  /// IndexCalculator callback.
  SmallVector<Operation *> buildMemRefStores(OpBuilder &b, Location loc,
                                             ValueRange toStore,
                                             OpFoldResult laneId, Value memref,
                                             const IndexCalculator &indexFn);

  /// Perform a distributed store of a vector operand of `vectorShape` for a
  /// particular MMA instruction whose `(row, col)` indices are specified via
  /// the IndexCalculator callback. Each `laneId` stores the subportion of the
  /// data that makes sense for the particular MMA operation.
  /// The `vectorShape` matches existing NVGPU dialect op specification but
  /// could also be flattened in the future if needed for simplification.
  SmallVector<Operation *> buildMmaSyncMemRefStoreOperand(
      OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
      Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape);

  OpBuilder &b;
  Location loc;
  OpFoldResult laneId;
};

//===--------------------------------------------------------------------===//
/// Helper functions to create customizable load and store operations. The
/// specific shapes of each MMA instruction are passed via the
/// IndexCalculator callback.
//===--------------------------------------------------------------------===//

template <typename ApplyFn, typename ReduceFn>
static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn,
                                           ReduceFn reduceFn) {
  VectorType vectorType = cast<VectorType>(vector.getType());
  auto vectorShape = vectorType.getShape();
  auto strides = computeStrides(vectorShape);
  for (int64_t idx = 0, e = vectorShape[0] * strides[0]; idx < e; ++idx) {
    auto indices = delinearize(idx, strides);
    reduceFn(applyFn(vector, idx, indices), idx, indices);
  }
}
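
// For example, iterating over a vector<2x2xf16> visits the elements in
// row-major order: linear indices 0..3 delinearize to {0,0}, {0,1}, {1,0}
// and {1,1} respectively.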

SmallVector<Value>
MmaSyncBuilder::buildMemRefLoads(OpBuilder &b, Location loc,
                                 OpFoldResult laneId, Value memref,
                                 const IndexCalculator &indexFn) {
  auto aff = [&](AffineExpr e) {
    return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
  };
  SmallVector<Value> res;
  SmallVector<RowColIndexing> indexings = indexFn(b.getContext());
  for (auto indexing : indexings) {
    Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
    Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
    auto load = memref::LoadOp::create(b, loc, memref, ValueRange{row, col});
    res.push_back(load);
  }
  return res;
}

Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
    OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
    IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
  auto loads = buildMemRefLoads(b, loc, laneId, memref, indexFn);

  Type elementType = getElementTypeOrSelf(memref.getType());
  auto vt = VectorType::get(vectorShape, elementType);
  Value res = vector::BroadcastOp::create(b, loc, vt, loads[0]);
  foreachIndividualVectorElement(
      res,
      /*applyFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        return loads[linearIdx];
      },
      /*reduceFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        res = vector::InsertOp::create(b, loc, v, res, indices);
      });

  return res;
}

SmallVector<Operation *> MmaSyncBuilder::buildMemRefStores(
    OpBuilder &b, Location loc, ValueRange toStore, OpFoldResult laneId,
    Value memref, const IndexCalculator &indexFn) {
  auto aff = [&](AffineExpr e) {
    return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
  };
  SmallVector<Operation *> res;
  for (auto [indexing, val] :
       llvm::zip_equal(indexFn(b.getContext()), toStore)) {
    Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
    Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
    Operation *store =
        memref::StoreOp::create(b, loc, val, memref, ValueRange{row, col});
    res.push_back(store);
  }
  return res;
}

SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemRefStoreOperand(
    OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
    Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
  SmallVector<Value> toStore;
  toStore.reserve(32);
  foreachIndividualVectorElement(
      vectorToStore,
      /*applyFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        return vector::ExtractOp::create(b, loc, vectorToStore, indices);
      },
      /*reduceFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        toStore.push_back(v);
      });
  return buildMemRefStores(b, loc, toStore, laneId, memref, indexFn);
}

static std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
                  SmallVector<int64_t>>
makeVectorShapes(ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs,
                 ArrayRef<int64_t> res) {
  SmallVector<int64_t> vlhs(lhs);
  SmallVector<int64_t> vrhs(rhs);
  SmallVector<int64_t> vres(res);
  return std::make_tuple(vlhs, vrhs, vres);
}

FailureOr<MmaSyncBuilder::MmaSyncInfo>
MmaSyncBuilder::getIndexCalculators(ArrayRef<int64_t> opShape,
                                    TypeRange elementalTypes) {
  // TODO: Tablegen all this.
  Type f16 = b.getF16Type();
  Type f32 = b.getF32Type();
  if (opShape == ArrayRef<int64_t>{16, 8, 4} &&
      elementalTypes == TypeRange{f32, f32, f32}) {
    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k4tf32Lhs,
                                       &MmaSyncBuilder::m16n8k4tf32Rhs,
                                       &MmaSyncBuilder::m16n8k4tf32Res),
                       makeVectorShapes({2, 1}, {1, 1}, {2, 2}),
                       SmallVector<int64_t>{opShape},
                       /*tf32Enabled=*/true};
  }
  // This is the version with f16 accumulation.
  // TODO: version with f32 accumulation.
  if (opShape == ArrayRef<int64_t>{16, 8, 16} &&
      elementalTypes == TypeRange{f16, f16, f16}) {
    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k16f16Lhs,
                                       &MmaSyncBuilder::m16n8k16f16Rhs,
                                       &MmaSyncBuilder::m16n8k16f16Res),
                       makeVectorShapes({4, 2}, {2, 2}, {2, 2}),
                       SmallVector<int64_t>{opShape},
                       /*tf32Enabled=*/false};
  }
  return failure();
}
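
// For example, a 16x8x4 matmul whose LHS, RHS and accumulator are all f32
// takes the tf32 path above: each lane owns a 2x1 LHS fragment, a 1x1 RHS
// fragment and a 2x2 accumulator fragment, and tf32Enabled is set on the
// resulting mma.sync.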

FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
  Value lhsMemRef = linalgOp.getDpsInputOperand(0)->get();
  Value rhsMemRef = linalgOp.getDpsInputOperand(1)->get();
  Value resMemRef = linalgOp.getDpsInitOperand(0)->get();
  assert(cast<MemRefType>(lhsMemRef.getType()).getRank() == 2 &&
         "expected lhs to be a 2D memref");
  assert(cast<MemRefType>(rhsMemRef.getType()).getRank() == 2 &&
         "expected rhs to be a 2D memref");
  assert(cast<MemRefType>(resMemRef.getType()).getRank() == 2 &&
         "expected res to be a 2D memref");

  int64_t m = cast<MemRefType>(lhsMemRef.getType()).getShape()[0];
  int64_t n = cast<MemRefType>(rhsMemRef.getType()).getShape()[1];
  int64_t k = cast<MemRefType>(lhsMemRef.getType()).getShape()[1];
  Type lhsType = getElementTypeOrSelf(lhsMemRef.getType());
  Type rhsType = getElementTypeOrSelf(rhsMemRef.getType());
  Type resType = getElementTypeOrSelf(resMemRef.getType());

  FailureOr<MmaSyncInfo> maybeInfo =
      getIndexCalculators({m, n, k}, {lhsType, rhsType, resType});
  if (failed(maybeInfo))
    return failure();

  const MmaSyncInfo &info = *maybeInfo;
  auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns;
  auto [lhsShape, rhsShape, resShape] = info.vectorShapes;
  Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, lhsMemRef,
                                            lhsIndexFn, lhsShape);
  Value rhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, rhsMemRef,
                                            rhsIndexFn, rhsShape);
  Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef,
                                            resIndexFn, resShape);
  res =
      MmaSyncOp::create(b, loc, lhs, rhs, res, info.mmaShape, info.tf32Enabled);
  buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn,
                                 resShape);
  return res.getDefiningOp();
}

DiagnosedSilenceableFailure RewriteMatmulAsMmaSyncOp::applyToOne(
    TransformRewriter &rewriter, LinalgOp linalgOp,
    ApplyToEachResultList &results, TransformState &state) {
  bool fail = true;
  // TODO: more robust detection of matmulOp, with transposes etc.
  if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) {
    // Do not let matmuls with extended semantics (user-defined indexing maps)
    // through this transform.
    if (linalgOp.hasUserDefinedMaps()) {
      return emitSilenceableError()
             << "only matmul ops with non-extended semantics are supported";
    }
    Location loc = linalgOp.getLoc();
    // TODO: more robust computation of laneId, for now assume a single warp.
    Value laneId = gpu::ThreadIdOp::create(
        rewriter, loc, rewriter.getIndexType(), gpu::Dimension::x);
    if (succeeded(MmaSyncBuilder(rewriter, loc, laneId).buildMmaSync(linalgOp)))
      fail = false;
  }

  if (fail) {
    DiagnosedSilenceableFailure diag = emitSilenceableError()
                                       << "unsupported target op: " << linalgOp;
    diag.attachNote(linalgOp->getLoc()) << "target op";
    return diag;
  }

  rewriter.eraseOp(linalgOp);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// Hopper builders.
//===----------------------------------------------------------------------===//

/// Helper to create the base Hopper-specific operations that are reused in
/// various other places.
struct HopperBuilder {
  HopperBuilder(RewriterBase &rewriter, Location loc)
      : rewriter(rewriter), loc(loc) {}

  TypedValue<MBarrierGroupType>
  buildAndInitBarrierInSharedMemory(OpFoldResult numThreads);

  /// Create tma descriptor op to initiate transfer from global to shared
  /// memory. This must be done before the launch op, on the host.
  TypedValue<TensorMapDescriptorType>
  buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
                              gpu::LaunchOp launchOp);

  /// Build a tma load from global memory to shared memory using `barrier` to
  /// synchronize. Return the number of bytes that will be transferred.
  OpFoldResult buildTmaAsyncLoad(TypedValue<TensorMapDescriptorType> globalDesc,
                                 TypedValue<MemRefType> sharedMemref,
                                 TypedValue<MBarrierGroupType> barrier,
                                 SmallVectorImpl<Operation *> &loadOps);
  void buildBarrierArriveTx(TypedValue<MBarrierGroupType> barrier,
                            ArrayRef<OpFoldResult> sizes);

  /// If threadIdx.x == 0 does TMA request + wait, else just wait.
  /// Return the operation that performs the transfer on thread0.
  // TODO: In the future, don't hardcode to thread 0 but elect a leader.
  SmallVector<Operation *> buildPredicateLoadsOnThread0(
      ArrayRef<TypedValue<TensorMapDescriptorType>> globalDescriptors,
      ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
      TypedValue<MBarrierGroupType> barrier);

  void buildTryWaitParity(TypedValue<MBarrierGroupType> barrier);

  RewriterBase &rewriter;
  Location loc;
};

SmallVector<Operation *> HopperBuilder::buildPredicateLoadsOnThread0(
    ArrayRef<TypedValue<TensorMapDescriptorType>> globalDescriptors,
    ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
    TypedValue<MBarrierGroupType> barrier) {
  SmallVector<Operation *> loadOps;
  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  Value tidx = gpu::ThreadIdOp::create(rewriter, loc, gpu::Dimension::x);
  Value cond = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
                                     tidx, zero);
  // clang-format off
  scf::IfOp::create(rewriter,
    /*location=*/loc,
    /*conditional=*/cond,
    /*thenBuilder=*/
    [&](OpBuilder &lb, Location loc) {
      SmallVector<OpFoldResult> sizes;
      sizes.reserve(globalDescriptors.size());
      for (auto [desc, shmem] : llvm::zip_equal(
              globalDescriptors, sharedMemBuffers)) {
        OpFoldResult sz = buildTmaAsyncLoad(desc, shmem, barrier, loadOps);
        sizes.push_back(sz);
      }
      // TODO: Note that cutlass predeclares the barrier arrive tx before the tma.async.load.
      // This may or may not have perf implications.
      buildBarrierArriveTx(barrier, sizes);
      scf::YieldOp::create(rewriter, loc);
    },
    /*elseBuilder=*/
    [&](OpBuilder &lb, Location loc) {
      // TODO: is this for no-thread divergence?
      // Should we just yield the size and hoist?
      buildBarrierArriveTx(barrier, getAsIndexOpFoldResult(rewriter.getContext(), 0));
      scf::YieldOp::create(rewriter, loc);
    });
  // clang-format on
  return loadOps;
}

static Attribute getSharedAddressSpaceAttribute(OpBuilder &b) {
  return gpu::AddressSpaceAttr::get(
      b.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
  // return b.getI64IntegerAttr(static_cast<int64_t>(kSharedMemorySpace));
}

TypedValue<MBarrierGroupType>
HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) {
  auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
  Value barrier = MBarrierCreateOp::create(
      rewriter, loc,
      MBarrierGroupType::get(rewriter.getContext(), sharedMemorySpace));
  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  nvgpu::MBarrierInitOp::create(
      rewriter, loc, barrier,
      getValueOrCreateConstantIndexOp(rewriter, loc, numThreads), zero,
      Value());
  gpu::BarrierOp::create(rewriter, loc);
  return cast<TypedValue<MBarrierGroupType>>(barrier);
}

TypedValue<TensorMapDescriptorType>
HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
                                           gpu::LaunchOp launchOp) {
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(launchOp);
  Value unrankedMemRef = memref::CastOp::create(
      rewriter, loc,
      UnrankedMemRefType::get(memref.getType().getElementType(),
                              memref.getType().getMemorySpace()),
      memref);
  SmallVector<OpFoldResult> mixedSizes =
      memref::getMixedSizes(rewriter, loc, memref);
  SmallVector<Value> sizes =
      getValueOrCreateConstantIndexOp(rewriter, loc, mixedSizes);

  auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
  Value desc = TmaCreateDescriptorOp::create(
      rewriter, loc,
      TensorMapDescriptorType::get(rewriter.getContext(),
                                   MemRefType::Builder(memref.getType())
                                       .setMemorySpace(sharedMemorySpace),
                                   TensorMapSwizzleKind::SWIZZLE_NONE,
                                   TensorMapL2PromoKind::L2PROMO_NONE,
                                   TensorMapOOBKind::OOB_ZERO,
                                   TensorMapInterleaveKind::INTERLEAVE_NONE),
      unrankedMemRef, sizes);
  return cast<TypedValue<TensorMapDescriptorType>>(desc);
}

OpFoldResult HopperBuilder::buildTmaAsyncLoad(
    TypedValue<TensorMapDescriptorType> globalDesc,
    TypedValue<MemRefType> sharedMemref,
    TypedValue<MBarrierGroupType> barrier,
    SmallVectorImpl<Operation *> &loadOps) {
  MLIRContext *ctx = rewriter.getContext();
  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  Operation *loadOp =
      TmaAsyncLoadOp::create(rewriter, loc, sharedMemref, barrier, globalDesc,
                             ValueRange{zero, zero}, zero, Value(), Value());
  loadOps.push_back(loadOp);
  auto mixedSizes = memref::getMixedSizes(rewriter, loc, sharedMemref);
  SmallVector<AffineExpr> symbols(mixedSizes.size());
  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
  AffineExpr prodExprInBytes =
      computeProduct(ctx, symbols) *
      (sharedMemref.getType().getElementTypeBitWidth() / 8);
  auto res = affine::makeComposedFoldedAffineApply(rewriter, loc,
                                                   prodExprInBytes, mixedSizes);
  return res;
}

void HopperBuilder::buildBarrierArriveTx(TypedValue<MBarrierGroupType> barrier,
                                         ArrayRef<OpFoldResult> mixedSizes) {
  assert(!mixedSizes.empty() && "expected non-empty sizes");
  MLIRContext *ctx = rewriter.getContext();
  SmallVector<AffineExpr> symbols(mixedSizes.size());
  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
  AffineExpr sumExpr = computeSum(ctx, symbols);
  OpFoldResult size =
      affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, mixedSizes);
  Value sizeVal = getValueOrCreateConstantIndexOp(rewriter, loc, size);
  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  nvgpu::MBarrierArriveExpectTxOp::create(rewriter, loc, barrier, sizeVal, zero,
                                          Value());
}

void HopperBuilder::buildTryWaitParity(TypedValue<MBarrierGroupType> barrier) {
  Type i1 = rewriter.getI1Type();
  Value parity = LLVM::ConstantOp::create(rewriter, loc, i1, 0);
  // 10M is an arbitrary, not too small or too big number to specify the number
  // of ticks before retry.
  // TODO: hoist this in a default dialect constant.
  Value ticksBeforeRetry =
      arith::ConstantIndexOp::create(rewriter, loc, 10000000);
  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  nvgpu::MBarrierTryWaitParityOp::create(rewriter, loc, barrier, parity,
                                         ticksBeforeRetry, zero);
}

//===----------------------------------------------------------------------===//
// RewriteCopyAsTmaOp
//===----------------------------------------------------------------------===//

/// Helper to create the tma operations corresponding to `linalg::CopyOp`.
struct CopyBuilder : public HopperBuilder {
  CopyBuilder(RewriterBase &rewriter, Location loc)
      : HopperBuilder(rewriter, loc) {}

  SmallVector<Operation *> rewrite(ArrayRef<Operation *> copyOps);
};

SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) {
  MLIRContext *ctx = rewriter.getContext();
  if (copyOps.empty())
    return SmallVector<Operation *>();

  auto launchOp = copyOps.front()->getParentOfType<gpu::LaunchOp>();
  assert(launchOp && "expected launch op");

  // 1. Init a barrier object in shared memory.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(copyOps.front());
  AffineExpr bx, by, bz;
  bindSymbols(ctx, bx, by, bz);
  AffineExpr prod = computeProduct(ctx, ArrayRef<AffineExpr>{bx, by, bz});
  OpFoldResult numThreads = affine::makeComposedFoldedAffineApply(
      rewriter, loc, prod,
      ArrayRef<OpFoldResult>{launchOp.getBlockSizeX(), launchOp.getBlockSizeY(),
                             launchOp.getBlockSizeZ()});

  TypedValue<MBarrierGroupType> barrier =
      buildAndInitBarrierInSharedMemory(numThreads);

  SmallVector<TypedValue<MemRefType>> shmems;
  SmallVector<TypedValue<TensorMapDescriptorType>> globalDescs;
  for (Operation *op : copyOps) {
    auto copyOp = cast<linalg::CopyOp>(op);
    auto inMemRef =
        cast<TypedValue<MemRefType>>(copyOp.getDpsInputOperand(0)->get());
    assert(inMemRef.getType().getRank() == 2 &&
           "expected in to be a 2D memref");

    // 2. Build global memory descriptor.
    TypedValue<TensorMapDescriptorType> globalDesc =
        buildGlobalMemRefDescriptor(inMemRef, launchOp);
    globalDescs.push_back(globalDesc);

    // 3. Shared memory and descriptor for the tmp array.
    auto shmem =
        cast<TypedValue<MemRefType>>(copyOp.getDpsInitOperand(0)->get());
    shmems.push_back(shmem);
  }

  // 4. Load in from global memory to shared memory using tma.
  OpBuilder::InsertionGuard g2(rewriter);
  rewriter.setInsertionPoint(copyOps.front());
  SmallVector<Operation *> results =
      buildPredicateLoadsOnThread0(globalDescs, shmems, barrier);

  // 5. Spin-loop until data is ready.
  buildTryWaitParity(barrier);

  // 6. Erase the ops that have now been rewritten.
  for (Operation *op : copyOps)
    rewriter.eraseOp(op);

  return results;
}

DiagnosedSilenceableFailure
RewriteCopyAsTmaOp::apply(TransformRewriter &rewriter,
                          TransformResults &results, TransformState &state) {
  auto payloadOps = state.getPayloadOps(getTarget());
  gpu::LaunchOp commonLaunchOp;
  Operation *firstOp, *failingOp;
  if (llvm::any_of(payloadOps, [&](Operation *op) {
        if (!commonLaunchOp) {
          commonLaunchOp = op->getParentOfType<gpu::LaunchOp>();
          firstOp = op;
        }
        auto fail = !op->getParentOfType<gpu::LaunchOp>() ||
                    commonLaunchOp != op->getParentOfType<gpu::LaunchOp>() ||
                    !isa<linalg::CopyOp>(op);
        if (fail)
          failingOp = op;
        return fail;
      })) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "target ops must be linalg::CopyOp nested under a common "
           "gpu.LaunchOp to be rewritten because the tma descriptors need to "
           "be created on the host.\nBut got: "
        << *firstOp << "\nand " << *failingOp;
    return diag;
  }

  // TODO: more robust detection of copy, with transposes etc.
  CopyBuilder(rewriter, getLoc()).rewrite(llvm::to_vector(payloadOps));

  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//

namespace {
class NVGPUTransformDialectExtension
    : public TransformDialectExtension<NVGPUTransformDialectExtension> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NVGPUTransformDialectExtension)

  NVGPUTransformDialectExtension() {
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<affine::AffineDialect>();
    declareGeneratedDialect<NVGPUDialect>();
    declareGeneratedDialect<NVVM::NVVMDialect>();
    declareGeneratedDialect<vector::VectorDialect>();
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
        >();
  }
};
} // namespace

#define GET_OP_CLASSES
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"

void mlir::nvgpu::registerTransformDialectExtension(DialectRegistry &registry) {
  registry.addExtensions<NVGPUTransformDialectExtension>();
}
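
// Example usage (a sketch, not part of this file): a tool that wants these
// transform ops available can attach the extension to its registry before
// building the context; the exact namespace of
// registerTransformDialectExtension is assumed from this file's context.
//
//   mlir::DialectRegistry registry;
//   mlir::nvgpu::registerTransformDialectExtension(registry);
//   mlir::MLIRContext context(registry);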