//===- GPUTransformOps.cpp - Implementation of GPU transform ops ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"

// Note: the original listing elides part of the include list; the dialect and
// conversion headers below are reconstructed from the symbols used in this
// file.
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/TransformOps/Utils.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include <type_traits>

using namespace mlir;
using namespace mlir::gpu;
using namespace mlir::transform;
using namespace mlir::transform::gpu;

#define DEBUG_TYPE "gpu-transforms"
#define DEBUG_TYPE_ALIAS "gpu-transforms-alias"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
#define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ")

//===----------------------------------------------------------------------===//
// Apply...ConversionPatternsOp
//===----------------------------------------------------------------------===//

void transform::ApplyGPUToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  // NVVM uses alloca in the default address space to represent private
  // memory allocations, so drop private annotations. NVVM uses address
  // space 3 for shared memory. NVVM uses the default address space to
  // represent global memory.
  // Used in populateGpuToNVVMConversionPatterns, so attaching here for now.
  // TODO: We should have a single to_nvvm_type_converter.
  populateGpuMemorySpaceAttributeConversions(
      llvmTypeConverter, [](AddressSpace space) -> unsigned {
        switch (space) {
        case AddressSpace::Global:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kGlobalMemorySpace);
        case AddressSpace::Workgroup:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kSharedMemorySpace);
        case AddressSpace::Private:
          return 0;
        }
        llvm_unreachable("unknown address space enum value");
        return 0;
      });
  // Used in GPUToNVVM/WmmaOpsToNvvm.cpp so attaching here for now.
  // TODO: We should have a single to_nvvm_type_converter.
  llvmTypeConverter.addConversion(
      [&](MMAMatrixType type) -> Type { return convertMMAToLLVMType(type); });
  populateGpuToNVVMConversionPatterns(llvmTypeConverter, patterns);
}
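
// Illustrative note (not part of the upstream file): with the memory-space
// conversion installed above, a memref in `#gpu.address_space<workgroup>`
// lowers to an LLVM pointer in address space 3, `#gpu.address_space<private>`
// lowers to the default address space 0, and `#gpu.address_space<global>`
// maps to NVVM's global memory space constant.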

LogicalResult
transform::ApplyGPUToNVVMConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}

void transform::ApplyGPUWwmaToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  populateGpuWMMAToNVVMConversionPatterns(llvmTypeConverter, patterns);
}

LogicalResult
transform::ApplyGPUWwmaToNVVMConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}

void transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp::
    populatePatterns(TypeConverter &typeConverter,
                     RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  populateGpuSubgroupReduceOpLoweringPattern(llvmTypeConverter, patterns);
}

LogicalResult transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp::
    verifyTypeConverter(transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}

//===----------------------------------------------------------------------===//
// Apply...PatternsOp
//===----------------------------------------------------------------------===//

void ApplyGPURewritePatternsOp::populatePatterns(RewritePatternSet &patterns) {
  populateGpuRewritePatterns(patterns);
}

//===----------------------------------------------------------------------===//
// ApplyUnrollVectorsSubgroupMmaOp
//===----------------------------------------------------------------------===//

/// Pick an unrolling order that will allow tensor core operations to reuse the
/// LHS register.
static std::optional<SmallVector<int64_t>>
gpuMmaUnrollOrder(vector::ContractionOp contract) {
  SmallVector<int64_t> order;
  // First make reduction the outer dimensions.
  for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) {
    if (vector::isReductionIterator(iter)) {
      order.push_back(index);
    }
  }

  llvm::SmallDenseSet<int64_t> dims;
  for (AffineExpr expr : contract.getIndexingMapsArray()[0].getResults()) {
    dims.insert(cast<AffineDimExpr>(expr).getPosition());
  }
  // Then parallel dimensions that are part of the LHS, as we want to re-use
  // the LHS.
  for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) {
    if (vector::isParallelIterator(iter) && dims.count(index)) {
      order.push_back(index);
    }
  }
  // Then the remaining parallel loops.
  for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) {
    if (vector::isParallelIterator(iter) && !dims.count(index)) {
      order.push_back(index);
    }
  }
  return order;
}
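
// Worked example (illustration only): for a canonical matmul contraction with
// iterator_types = [parallel, parallel, reduction] and an LHS indexing map of
// (d0, d1, d2) -> (d0, d2), the reduction dimension comes first, then the
// parallel dimension used by the LHS, then the rest: order = [2, 0, 1].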

/// Returns the target vector size for the target operation based on the native
/// vector size specified with `m`, `n`, and `k`.
static std::optional<SmallVector<int64_t>>
getSubgroupMmaNativeVectorSize(Operation *op, int64_t m, int64_t n, int64_t k) {
  if (auto contract = dyn_cast<vector::ContractionOp>(op)) {
    int64_t contractRank = contract.getIteratorTypes().size();
    if (contractRank < 3)
      return std::nullopt;
    SmallVector<int64_t> nativeSize(contractRank - 3, 1);
    nativeSize.append({m, n, k});
    return nativeSize;
  }
  if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
    int64_t writeRank = writeOp.getVectorType().getRank();
    if (writeRank < 2)
      return std::nullopt;
    SmallVector<int64_t> nativeSize(writeRank - 2, 1);
    nativeSize.append({m, n});
    return nativeSize;
  }
  if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
    // Transfer read ops may need different shapes based on how they are being
    // used. For simplicity just match the shape used by the extract strided op.
    VectorType sliceType;
    for (Operation *users : op->getUsers()) {
      auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
      if (!extract)
        return std::nullopt;
      auto vecType = cast<VectorType>(extract.getResult().getType());
      if (sliceType && sliceType != vecType)
        return std::nullopt;
      sliceType = vecType;
    }
    return llvm::to_vector(sliceType.getShape());
  }
  if ((OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1)) {
    if (auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0])) {
      // TODO: The condition for unrolling elementwise should be restricted
      // only to operations that need unrolling (connected to the contract).
      if (vecType.getRank() < 2)
        return std::nullopt;

      // First check whether there is a slice to infer the shape from. This is
      // required for cases where the accumulator type differs from the input
      // types, in which case we will see an `arith.ext_` between the contract
      // and transfer_read which needs to be unrolled.
      VectorType sliceType;
      for (Operation *users : op->getUsers()) {
        auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
        if (!extract)
          return std::nullopt;
        auto vecType = cast<VectorType>(extract.getResult().getType());
        if (sliceType && sliceType != vecType)
          return std::nullopt;
        sliceType = vecType;
      }
      if (sliceType)
        return llvm::to_vector(sliceType.getShape());

      // Else unroll for trailing elementwise.
      SmallVector<int64_t> nativeSize(vecType.getRank() - 2, 1);
      // Map elementwise ops to the output shape.
      nativeSize.append({m, n});
      return nativeSize;
    }
  }
  return std::nullopt;
}
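
// Worked example (illustration only): with native sizes m = 16, n = 16, k = 8,
// a 3-D vector.contract unrolls to [16, 16, 8], a rank-2 vector.transfer_write
// unrolls to [16, 16], and a 4-D contract (one extra batch dimension) unrolls
// to [1, 16, 16, 8].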

void transform::ApplyUnrollVectorsSubgroupMmaOp::populatePatterns(
    RewritePatternSet &patterns) {
  auto unrollOrder = [](Operation *op) -> std::optional<SmallVector<int64_t>> {
    auto contract = dyn_cast<vector::ContractionOp>(op);
    if (!contract)
      return std::nullopt;
    return gpuMmaUnrollOrder(contract);
  };

  int64_t m = getM();
  int64_t n = getN();
  int64_t k = getK();
  auto nativeShapeFn =
      [m, n, k](Operation *op) -> std::optional<SmallVector<int64_t>> {
    return getSubgroupMmaNativeVectorSize(op, m, n, k);
  };
  vector::populateVectorUnrollPatterns(
      patterns, vector::UnrollVectorOptions()
                    .setNativeShapeFn(nativeShapeFn)
                    .setUnrollTraversalOrderFn(unrollOrder));
}
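
// Illustrative note (not part of the upstream file): with m = n = 16 and
// k = 8, unrolling a 32x32x16 vector.contract produces (32/16) * (32/16) *
// (16/8) = 8 native-sized contractions, traversed in the order returned by
// gpuMmaUnrollOrder so that LHS fragments are reused across consecutive steps.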

//===----------------------------------------------------------------------===//
// EliminateBarriersOp
//===----------------------------------------------------------------------===//

void EliminateBarriersOp::populatePatterns(RewritePatternSet &patterns) {
  populateGpuEliminateBarriersPatterns(patterns);
}

//===----------------------------------------------------------------------===//
// Block and thread mapping utilities.
//===----------------------------------------------------------------------===//

namespace {
/// Local types used for mapping verification.
struct MappingKind {};
struct BlockMappingKind : MappingKind {};
struct ThreadMappingKind : MappingKind {};
} // namespace

static DiagnosedSilenceableFailure
definiteFailureHelper(std::optional<TransformOpInterface> transformOp,
                      Operation *target, const Twine &message) {
  if (transformOp.has_value())
    return transformOp->emitDefiniteFailure() << message;
  return emitDefiniteFailure(target, message);
}

/// Check if the given mapping attributes are one of the desired attributes.
template <typename MappingKindType>
static DiagnosedSilenceableFailure
checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
                           scf::ForallOp forallOp) {
  if (!forallOp.getMapping().has_value()) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall op requires a mapping attribute");
  }

  bool hasBlockMapping = llvm::any_of(forallOp.getMapping().value(),
                                      llvm::IsaPred<GPUBlockMappingAttr>);
  bool hasWarpgroupMapping = llvm::any_of(
      forallOp.getMapping().value(), llvm::IsaPred<GPUWarpgroupMappingAttr>);
  bool hasWarpMapping = llvm::any_of(forallOp.getMapping().value(),
                                     llvm::IsaPred<GPUWarpMappingAttr>);
  bool hasThreadMapping = llvm::any_of(forallOp.getMapping().value(),
                                       llvm::IsaPred<GPUThreadMappingAttr>);
  int64_t countMappingTypes = 0;
  countMappingTypes += hasBlockMapping ? 1 : 0;
  countMappingTypes += hasWarpgroupMapping ? 1 : 0;
  countMappingTypes += hasWarpMapping ? 1 : 0;
  countMappingTypes += hasThreadMapping ? 1 : 0;
  if (countMappingTypes > 1) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "cannot mix different mapping types, use nesting");
  }
  if (std::is_same<MappingKindType, BlockMappingKind>::value &&
      !hasBlockMapping) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "scf.forall op requires a mapping attribute of kind 'block'");
  }
  if (std::is_same<MappingKindType, ThreadMappingKind>::value &&
      !hasThreadMapping && !hasWarpMapping && !hasWarpgroupMapping) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall op requires a mapping attribute "
                                 "of kind 'thread' or 'warp'");
  }

  DenseSet<Attribute> seen;
  for (Attribute map : forallOp.getMapping()->getValue()) {
    if (seen.contains(map)) {
      return definiteFailureHelper(
          transformOp, forallOp,
          "duplicate attribute, cannot map different loops "
          "to the same mapping id");
    }
    seen.insert(map);
  }

  auto isLinear = [](Attribute a) {
    return cast<DeviceMappingAttrInterface>(a).isLinearMapping();
  };
  if (llvm::any_of(forallOp.getMapping()->getValue(), isLinear) &&
      !llvm::all_of(forallOp.getMapping()->getValue(), isLinear)) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "cannot mix linear and non-linear mapping modes");
  }

  return DiagnosedSilenceableFailure::success();
}
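
// Worked examples (illustration only): a mapping of
// [#gpu.thread<x>, #gpu.block<y>] is rejected with "cannot mix different
// mapping types, use nesting"; [#gpu.thread<x>, #gpu.thread<x>] is rejected as
// a duplicate attribute; and [#gpu.thread<linear_dim_0>, #gpu.thread<y>] is
// rejected for mixing linear and non-linear mapping modes.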

template <typename MappingKindType>
static DiagnosedSilenceableFailure
verifyGpuMapping(std::optional<TransformOpInterface> transformOp,
                 scf::ForallOp forallOp) {
  // Check the types of the mapping attributes match.
  DiagnosedSilenceableFailure typeRes =
      checkMappingAttributeTypes<MappingKindType>(transformOp, forallOp);
  if (!typeRes.succeeded())
    return typeRes;

  // Perform other non-type verifications.
  if (!forallOp.isNormalized())
    return definiteFailureHelper(transformOp, forallOp,
                                 "unsupported non-normalized loops");
  if (forallOp.getNumResults() > 0)
    return definiteFailureHelper(transformOp, forallOp,
                                 "only bufferized scf.forall can be mapped");
  bool useLinearMapping = cast<DeviceMappingAttrInterface>(
                              forallOp.getMapping()->getValue().front())
                              .isLinearMapping();
  // TODO: This would be more natural with support for Optional<EnumParameter>
  // in GPUDeviceMappingAttr.
  int64_t maxNumMappingsSupported =
      useLinearMapping ? (getMaxEnumValForMappingId() -
                          static_cast<uint64_t>(MappingId::DimZ))
                       : 3;
  if (forallOp.getRank() > maxNumMappingsSupported) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall with rank > ")
           << maxNumMappingsSupported
           << " does not lower for the specified mapping attribute type";
  }
  auto numParallelIterations =
      getConstantIntValues(forallOp.getMixedUpperBound());
  if (!forallOp.isNormalized() || !numParallelIterations.has_value()) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "requires statically sized, normalized forall op");
  }
  return DiagnosedSilenceableFailure::success();
}

/// Struct to return the result of the rewrite of a forall operation.
struct ForallRewriteResult {
  SmallVector<int64_t> mappingSizes;
  SmallVector<Value> mappingIds;
};

/// Helper to replace ids of dimensions known to be 1 by 0 to simplify the IR.
template <typename OpTy, typename OperationOrBlock>
static void
replaceUnitMappingIdsHelper(RewriterBase &rewriter, Location loc,
                            OperationOrBlock *parent, Value replacement,
                            ArrayRef<int64_t> availableMappingSizes) {
  parent->walk([&](OpTy idOp) {
    if (availableMappingSizes[static_cast<int64_t>(idOp.getDimension())] == 1)
      rewriter.replaceAllUsesWith(idOp.getResult(), replacement);
  });
}
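
// Worked example (illustration only): with availableMappingSizes = {4, 2, 1},
// every `OpTy` op for dimension z (the dimension of size 1) found under
// `parent` has its result replaced by the `replacement` value, which is a
// constant 0 at the call sites below.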

static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    scf::ForallOp forallOp, ArrayRef<int64_t> availableMappingSizes,
    ForallRewriteResult &result, const GpuIdBuilder &gpuIdBuilder) {
  LDBG("--start rewriteOneForallCommonImpl");

  // Step 1. Complete the mapping to a full mapping (with 1s) if necessary.
  auto numParallelIterations =
      getConstantIntValues(forallOp.getMixedUpperBound());
  assert(forallOp.isNormalized() && numParallelIterations.has_value() &&
         "requires statically sized, normalized forall op");
  SmallVector<int64_t> tmpMappingSizes = numParallelIterations.value();
  SetVector<Attribute> forallMappingAttrs;
  forallMappingAttrs.insert(forallOp.getMapping()->getValue().begin(),
                            forallOp.getMapping()->getValue().end());
  auto comparator = [](Attribute a, Attribute b) -> bool {
    return cast<DeviceMappingAttrInterface>(a).getMappingId() <
           cast<DeviceMappingAttrInterface>(b).getMappingId();
  };

  // Step 1.b. In the linear case, compute the max mapping to avoid needlessly
  // mapping all dimensions. In the 3-D mapping case we need to map all
  // dimensions.
  DeviceMappingAttrInterface maxMapping = cast<DeviceMappingAttrInterface>(
      *llvm::max_element(forallMappingAttrs, comparator));
  DeviceMappingAttrInterface maxLinearMapping;
  if (maxMapping.isLinearMapping())
    maxLinearMapping = maxMapping;
  for (auto attr : gpuIdBuilder.mappingAttributes) {
    // If attr overflows, just skip.
    if (maxLinearMapping && comparator(maxLinearMapping, attr))
      continue;
    // Try to insert. If the element was already present, just continue.
    if (!forallMappingAttrs.insert(attr))
      continue;
    // Otherwise, we have a new insertion without a size -> use size 1.
    tmpMappingSizes.push_back(1);
  }
  LLVM_DEBUG(
      llvm::interleaveComma(
          tmpMappingSizes,
          DBGS() << "----tmpMappingSizes extracted from scf.forall op: ");
      llvm::dbgs() << "\n");

  // Step 2. Sort the values by the corresponding DeviceMappingAttrInterface.
  SmallVector<int64_t> forallMappingSizes = getValuesSortedByKey(
      forallMappingAttrs.getArrayRef(), tmpMappingSizes, comparator);
  LLVM_DEBUG(llvm::interleaveComma(forallMappingSizes,
                                   DBGS() << "----forallMappingSizes: ");
             llvm::dbgs() << "\n"; llvm::interleaveComma(
                 forallMappingAttrs, DBGS() << "----forallMappingAttrs: ");
             llvm::dbgs() << "\n");

  // Step 3. Generate the mappingIdOps using the provided generator.
  Location loc = forallOp.getLoc();
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(forallOp);
  SmallVector<int64_t> originalBasis(availableMappingSizes);
  bool originalBasisWasProvided = !originalBasis.empty();
  if (!originalBasisWasProvided) {
    originalBasis = forallMappingSizes;
    while (originalBasis.size() < 3)
      originalBasis.push_back(1);
  }

  IdBuilderResult builderResult =
      gpuIdBuilder.idBuilder(rewriter, loc, forallMappingSizes, originalBasis);

  // Step 4. Map the induction variables to the mappingIdOps, this may involve
  // a permutation.
  SmallVector<Value> mappingIdOps = builderResult.mappingIdOps;
  IRMapping bvm;
  for (auto [iv, dim] : llvm::zip_equal(
           forallOp.getInductionVars(),
           forallMappingAttrs.getArrayRef().take_front(forallOp.getRank()))) {
    auto mappingAttr = cast<DeviceMappingAttrInterface>(dim);
    Value peIdOp = mappingIdOps[mappingAttr.getRelativeIndex()];
    bvm.map(iv, peIdOp);
  }

  // Step 5. If the originalBasis is already known, create conditionals to
  // predicate the region. Otherwise, the current forall determines the
  // originalBasis and no predication occurs.
  Value predicate;
  if (originalBasisWasProvided) {
    SmallVector<int64_t> activeMappingSizes = builderResult.activeMappingSizes;
    SmallVector<int64_t> availableMappingSizes =
        builderResult.availableMappingSizes;
    SmallVector<Value> activeIdOps = builderResult.activeIdOps;
    // clang-format off
    LLVM_DEBUG(
        llvm::interleaveComma(
          activeMappingSizes, DBGS() << "----activeMappingSizes: ");
        llvm::dbgs() << "\n";
        llvm::interleaveComma(
          availableMappingSizes, DBGS() << "----availableMappingSizes: ");
        llvm::dbgs() << "\n";
        llvm::interleaveComma(activeIdOps, DBGS() << "----activeIdOps: ");
        llvm::dbgs() << "\n");
    // clang-format on
    for (auto [activeId, activeMappingSize, availableMappingSize] :
         llvm::zip_equal(activeIdOps, activeMappingSizes,
                         availableMappingSizes)) {
      if (activeMappingSize > availableMappingSize) {
        return definiteFailureHelper(
            transformOp, forallOp,
            "Trying to map to fewer GPU threads than loop iterations but "
            "overprovisioning is not yet supported. "
            "Try additional tiling before mapping or map to more threads.");
      }
      if (activeMappingSize == availableMappingSize)
        continue;
      Value idx =
          rewriter.create<arith::ConstantIndexOp>(loc, activeMappingSize);
      Value tmpPredicate = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::ult, activeId, idx);
      LDBG("----predicate: " << tmpPredicate);
      predicate = predicate ? rewriter.create<arith::AndIOp>(loc, predicate,
                                                             tmpPredicate)
                            : tmpPredicate;
    }
  }

  // Step 6. Move the body of forallOp.
  // Erase the terminator first, it will not be used.
  rewriter.eraseOp(forallOp.getTerminator());
  Block *targetBlock;
  Block::iterator insertionPoint;
  if (predicate) {
    // Step 6.a. If predicated, move at the beginning.
    auto ifOp = rewriter.create<scf::IfOp>(loc, predicate,
                                           /*withElseRegion=*/false);
    targetBlock = ifOp.thenBlock();
    insertionPoint = ifOp.thenBlock()->begin();
  } else {
    // Step 6.b. Otherwise, move inline just at the rewriter insertion
    // point.
    targetBlock = forallOp->getBlock();
    insertionPoint = rewriter.getInsertionPoint();
  }
  Block &sourceBlock = forallOp.getRegion().front();
  targetBlock->getOperations().splice(insertionPoint,
                                      sourceBlock.getOperations());

  // Step 7. RAUW indices.
  for (Value loopIndex : forallOp.getInductionVars()) {
    Value threadIdx = bvm.lookup(loopIndex);
    rewriter.replaceAllUsesWith(loopIndex, threadIdx);
  }

  // Step 8. Erase old op.
  rewriter.eraseOp(forallOp);

  LLVM_DEBUG(llvm::interleaveComma(forallMappingSizes,
                                   DBGS() << "----result forallMappingSizes: ");
             llvm::dbgs() << "\n"; llvm::interleaveComma(
                 mappingIdOps, DBGS() << "----result mappingIdOps: ");
             llvm::dbgs() << "\n");

  result = ForallRewriteResult{forallMappingSizes, mappingIdOps};
  return DiagnosedSilenceableFailure::success();
}
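
// Worked example of Step 5 (illustration only): if builderResult reports
// activeMappingSizes = [2, 4] against availableMappingSizes = [4, 4], the
// first dimension is guarded by an `arith.cmpi ult` against the constant 2
// while the second dimension needs no guard; the combined predicate wraps the
// moved body in an scf.if so that excess threads skip the computation.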

//===----------------------------------------------------------------------===//
// MapForallToBlocks
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure mlir::transform::gpu::mapForallToBlocksImpl(
    RewriterBase &rewriter, TransformOpInterface transformOp,
    scf::ForallOp forallOp, SmallVectorImpl<int64_t> &gridDims,
    const GpuIdBuilder &gpuIdBuilder) {
  LDBG("Start mapForallToBlocksImpl");

  {
    // GPU-specific verifications. There is no better place to anchor
    // those right now: the ForallOp is target-independent and the transform
    // op does not apply to individual ForallOp.
    DiagnosedSilenceableFailure diag =
        verifyGpuMapping<BlockMappingKind>(transformOp, forallOp);
    if (!diag.succeeded())
      return diag;
  }

  Location loc = forallOp.getLoc();
  Block *parentBlock = forallOp->getBlock();
  Value zero;
  {
    // Create an early zero index value for replacements and immediately reset
    // the insertion point.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(parentBlock);
    zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  }

  ForallRewriteResult rewriteResult;
  DiagnosedSilenceableFailure diag = rewriteOneForallCommonImpl(
      rewriter, transformOp, forallOp,
      /*availableMappingSizes=*/gridDims, rewriteResult, gpuIdBuilder);

  // Return if anything goes wrong, use silenceable failure as a match
  // failure.
  if (!diag.succeeded())
    return diag;

  // If gridDims was not provided already, set it from the return.
  if (gridDims.empty()) {
    gridDims = rewriteResult.mappingSizes;
    while (gridDims.size() < 3)
      gridDims.push_back(1);
  }
  assert(gridDims.size() == 3 && "Need 3-D gridDims");

  // Replace ids of dimensions known to be 1 by 0 to simplify the IR.
  // Here, the result of mapping determines the available mapping sizes.
  replaceUnitMappingIdsHelper<BlockDimOp>(rewriter, loc, parentBlock, zero,
                                          rewriteResult.mappingSizes);

  return DiagnosedSilenceableFailure::success();
}
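
// Worked example (illustration only): an scf.forall with upper bounds (8, 4)
// mapped to [#gpu.block<x>, #gpu.block<y>] and an empty incoming gridDims
// yields gridDims = [8, 4, 1]; the trailing unit dimension is then simplified
// by replaceUnitMappingIdsHelper using the zero constant created above.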

DiagnosedSilenceableFailure
mlir::transform::gpu::findTopLevelForallOp(Operation *target,
                                           scf::ForallOp &topLevelForallOp,
                                           TransformOpInterface transformOp) {
  auto walkResult = target->walk([&](scf::ForallOp forallOp) {
    if (forallOp->getParentOfType<scf::ForallOp>())
      return WalkResult::advance();
    if (topLevelForallOp)
      // TODO: Handle multiple forall if they are independent.
      return WalkResult::interrupt();
    topLevelForallOp = forallOp;
    return WalkResult::advance();
  });

  if (walkResult.wasInterrupted() || !topLevelForallOp)
    return transformOp.emitSilenceableError()
           << "could not find a unique topLevel scf.forall";
  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure transform::MapForallToBlocks::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, TransformState &state) {
  LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
  auto transformOp = cast<TransformOpInterface>(getOperation());

  if (!getGenerateGpuLaunch() && !gpuLaunch) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "Given target is not gpu.launch, set `generate_gpu_launch` "
           "attribute";
    diag.attachNote(target->getLoc()) << "when applied to this payload op";
    return diag;
  }

  scf::ForallOp topLevelForallOp;
  DiagnosedSilenceableFailure diag =
      findTopLevelForallOp(target, topLevelForallOp, transformOp);
  if (!diag.succeeded()) {
    diag.attachNote(target->getLoc()) << "when applied to this payload op";
    return diag;
  }
  assert(topLevelForallOp && "expect an scf.forall");

  SmallVector<int64_t> gridDims{getGridDims()};
  if (!getGenerateGpuLaunch() && gridDims.size() != 3)
    return transformOp.emitDefiniteFailure("transform requires size-3 mapping");

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(topLevelForallOp);

  // Generate the gpu.launch here and move the forall inside.
  if (getGenerateGpuLaunch()) {
    DiagnosedSilenceableFailure diag =
        createGpuLaunch(rewriter, target->getLoc(), transformOp, gpuLaunch);
    if (!diag.succeeded())
      return diag;

    rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front());
    Operation *newForallOp = rewriter.clone(*topLevelForallOp);
    rewriter.eraseOp(topLevelForallOp);
    topLevelForallOp = cast<scf::ForallOp>(newForallOp);
  }

  // The BlockIdBuilder adapts to whatever is thrown at it.
  bool useLinearMapping = false;
  if (topLevelForallOp.getMapping()) {
    auto mappingAttr = cast<DeviceMappingAttrInterface>(
        topLevelForallOp.getMapping()->getValue().front());
    useLinearMapping = mappingAttr.isLinearMapping();
  }
  GpuBlockIdBuilder gpuBlockIdBuilder(getContext(), useLinearMapping);

  diag = mapForallToBlocksImpl(rewriter, transformOp, topLevelForallOp,
                               gridDims, gpuBlockIdBuilder);
  if (!diag.succeeded())
    return diag;

  // Set the GPU launch configuration for the grid dims late, this is
  // subject to IR inspection.
  diag = alterGpuLaunch(rewriter, gpuLaunch,
                        cast<TransformOpInterface>(getOperation()), gridDims[0],
                        gridDims[1], gridDims[2]);

  results.push_back(gpuLaunch);
  return diag;
}

LogicalResult transform::MapForallToBlocks::verify() {
  if (!getGridDims().empty() && getGridDims().size() != 3) {
    return emitOpError() << "transform requires empty or size-3 grid_dims";
  }
  return success();
}
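
// Usage sketch (illustration only; consult the GPU transform op documentation
// for the authoritative syntax). In transform IR this op is typically applied
// to a function or gpu.launch payload handle, roughly:
//
//   %gpu_launch = transform.gpu.map_forall_to_blocks %func
//       generate_gpu_launch grid_dims = [16, 16, 1]
//       : (!transform.any_op) -> !transform.any_op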

//===----------------------------------------------------------------------===//
// MapNestedForallToThreads
//===----------------------------------------------------------------------===//

static DiagnosedSilenceableFailure checkMappingSpec(
    std::optional<TransformOpInterface> transformOp, scf::ForallOp forallOp,
    ArrayRef<int64_t> numParallelIterations, ArrayRef<int64_t> blockOrGridSizes,
    int factor, bool useLinearMapping = false) {
  if (!useLinearMapping && blockOrGridSizes.front() % factor != 0) {
    auto diag = definiteFailureHelper(
        transformOp, forallOp,
        Twine("3-D mapping: size of threadIdx.x must be a multiple of ") +
            std::to_string(factor));
    return diag;
  }
  if (computeProduct(numParallelIterations) * factor >
      computeProduct(blockOrGridSizes)) {
    auto diag = definiteFailureHelper(
        transformOp, forallOp,
        Twine("the number of required parallel resources (blocks or "
              "threads) ") +
            std::to_string(computeProduct(numParallelIterations) * factor) +
            std::string(" overflows the number of available resources ") +
            std::to_string(computeProduct(blockOrGridSizes)));
    return diag;
  }
  return DiagnosedSilenceableFailure::success();
}
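
// Worked example (illustration only): for a warp-mapped forall with
// numParallelIterations = [4, 16], blockSizes = [128, 1, 1] and
// factor = warpSize = 32, the required resources are 4 * 16 * 32 = 2048, which
// overflows the 128 available threads and triggers the failure above.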

static DiagnosedSilenceableFailure
getThreadIdBuilder(std::optional<TransformOpInterface> transformOp,
                   scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes,
                   int64_t warpSize, GpuIdBuilder &gpuIdBuilder) {
  auto mappingAttr = cast<DeviceMappingAttrInterface>(
      forallOp.getMapping()->getValue().front());
  bool useLinearMapping = mappingAttr.isLinearMapping();

  // Sanity checks that may result in runtime verification errors.
  auto numParallelIterations =
      getConstantIntValues((forallOp.getMixedUpperBound()));
  if (!forallOp.isNormalized() || !numParallelIterations.has_value()) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "requires statically sized, normalized forall op");
  }
  int64_t factor = 1;
  if (isa<GPUWarpgroupMappingAttr>(mappingAttr)) {
    factor = GpuWarpgroupIdBuilder::kNumWarpsPerGroup * warpSize;
  } else if (isa<GPUWarpMappingAttr>(mappingAttr)) {
    factor = warpSize;
  }
  DiagnosedSilenceableFailure diag =
      checkMappingSpec(transformOp, forallOp, numParallelIterations.value(),
                       blockSizes, factor, useLinearMapping);
  if (!diag.succeeded())
    return diag;

  // Start mapping.
  MLIRContext *ctx = forallOp.getContext();
  gpuIdBuilder =
      llvm::TypeSwitch<DeviceMappingAttrInterface, GpuIdBuilder>(mappingAttr)
          .Case([&](GPUWarpgroupMappingAttr) {
            return GpuWarpgroupIdBuilder(ctx, warpSize, useLinearMapping);
          })
          .Case([&](GPUWarpMappingAttr) {
            return GpuWarpIdBuilder(ctx, warpSize, useLinearMapping);
          })
          .Case([&](GPUThreadMappingAttr) {
            return GpuThreadIdBuilder(ctx, useLinearMapping);
          })
          .Default([&](DeviceMappingAttrInterface) -> GpuIdBuilder {
            llvm_unreachable("unknown mapping attribute");
          });
  return DiagnosedSilenceableFailure::success();
}
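
// Illustrative note (not part of the upstream file): the unit of distribution
// scales the check above. A thread mapping uses factor = 1, a warp mapping
// uses factor = warpSize, and a warpgroup mapping uses
// kNumWarpsPerGroup * warpSize, so the same forall needs proportionally more
// threads per block as the mapping granularity grows.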

DiagnosedSilenceableFailure mlir::transform::gpu::mapOneForallToThreadsImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes, int64_t warpSize,
    bool syncAfterDistribute) {

  {
    // GPU-specific verifications. There is no better place to anchor
    // those right now: the ForallOp is target-independent and the transform
    // op does not apply to individual ForallOp.
    DiagnosedSilenceableFailure diag =
        verifyGpuMapping<ThreadMappingKind>(transformOp, forallOp);
    if (!diag.succeeded())
      return diag;
  }

  GpuIdBuilder gpuIdBuilder;
  {
    // Try to construct the id builder, if it fails, return.
    DiagnosedSilenceableFailure diag = getThreadIdBuilder(
        transformOp, forallOp, blockSizes, warpSize, gpuIdBuilder);
    if (!diag.succeeded())
      return diag;
  }

  Location loc = forallOp.getLoc();
  OpBuilder::InsertionGuard g(rewriter);
  // Insert after to allow for syncthreads after `forall` is erased.
  rewriter.setInsertionPointAfter(forallOp);
  ForallRewriteResult rewriteResult;
  DiagnosedSilenceableFailure diag = rewriteOneForallCommonImpl(
      rewriter, transformOp, forallOp, blockSizes, rewriteResult, gpuIdBuilder);
  if (!diag.succeeded())
    return diag;
  // Add a syncthreads if needed. TODO: warpsync
  if (syncAfterDistribute)
    rewriter.create<BarrierOp>(loc);

  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForallToThreadsImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    Operation *target, ArrayRef<int64_t> blockDims, int64_t warpSize,
    bool syncAfterDistribute) {
  LDBG("Start mapNestedForallToThreadsImpl");
  if (blockDims.size() != 3) {
    return definiteFailureHelper(transformOp, target,
                                 "requires size-3 thread mapping");
  }

  // Create an early zero index value for replacements.
  Location loc = target->getLoc();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
  WalkResult walkResult = target->walk([&](scf::ForallOp forallOp) {
    diag = mapOneForallToThreadsImpl(rewriter, transformOp, forallOp,
                                     blockDims, warpSize, syncAfterDistribute);
    if (diag.isDefiniteFailure())
      return WalkResult::interrupt();
    if (diag.succeeded())
      return WalkResult::skip();
    return WalkResult::advance();
  });
  if (walkResult.wasInterrupted())
    return diag;

  // Replace ids of dimensions known to be 1 by 0 to simplify the IR.
  // Here, the result of mapping determines the available mapping sizes.
  replaceUnitMappingIdsHelper<ThreadIdOp>(rewriter, loc, target, zero,
                                          blockDims);

  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure transform::MapNestedForallToThreads::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, TransformState &state) {
  LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
  auto transformOp = cast<TransformOpInterface>(getOperation());

  // Basic high-level verifications.
  if (!gpuLaunch)
    return emitSilenceableError() << "Given target is not a gpu.launch";

  // Mapping to block ids.
  SmallVector<int64_t> blockDims{getBlockDims()};
  DiagnosedSilenceableFailure diag =
      checkGpuLimits(transformOp, std::nullopt, std::nullopt, std::nullopt,
                     blockDims[0], blockDims[1], blockDims[2]);
  if (diag.isSilenceableFailure()) {
    diag.attachNote(getLoc()) << getBlockDimsAttrName() << " is too large";
    return diag;
  }

  // Set the GPU launch configuration for the block dims early, this is not
  // subject to IR inspection.
  diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, std::nullopt,
                        std::nullopt, std::nullopt, blockDims[0], blockDims[1],
                        blockDims[2]);

  rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front());
  diag =
      mapNestedForallToThreadsImpl(rewriter, transformOp, gpuLaunch, blockDims,
                                   getWarpSize(), getSyncAfterDistribute());

  results.push_back(gpuLaunch.getOperation());
  return diag;
}
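
// Usage sketch (illustration only; consult the GPU transform op documentation
// for the authoritative syntax). This op typically runs on the gpu.launch
// produced by transform.gpu.map_forall_to_blocks, roughly:
//
//   %mapped = transform.gpu.map_nested_forall_to_threads %gpu_launch
//       block_dims = [32, 4, 1]
//       : (!transform.any_op) -> !transform.any_op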

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//

namespace {
/// Registers new ops and declares PDL as dependent dialect since the
/// additional ops are using PDL types for operands and results.
class GPUTransformDialectExtension
    : public transform::TransformDialectExtension<
          GPUTransformDialectExtension> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GPUTransformDialectExtension)

  GPUTransformDialectExtension() {
    declareGeneratedDialect<scf::SCFDialect>();
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<GPUDialect>();
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"
        >();
  }
};
} // namespace

#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"

void mlir::gpu::registerTransformDialectExtension(DialectRegistry &registry) {
  registry.addExtensions<GPUTransformDialectExtension>();
}