//===- GPUTransformOps.cpp - Implementation of GPU transform ops ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"

#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/TransformOps/Utils.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/LogicalResult.h"
#include <optional>
#include <type_traits>

using namespace mlir;
using namespace mlir::gpu;
using namespace mlir::transform;
using namespace mlir::transform::gpu;

#define DEBUG_TYPE "gpu-transforms"

//===----------------------------------------------------------------------===//
// Apply...ConversionPatternsOp
//===----------------------------------------------------------------------===//

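// The ops in this section populate dialect-conversion patterns that lower GPU
// dialect ops to NVVM and ROCDL. They expect the type converter supplied by
// the enclosing transform op to be an LLVMTypeConverter; the
// verifyTypeConverter hooks below reject anything else.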
void transform::ApplyGPUToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  // NVVM uses alloca in the default address space to represent private
  // memory allocations, so drop private annotations. NVVM uses address
  // space 3 for shared memory. NVVM uses the default address space to
  // represent global memory.
  // Used in populateGpuToNVVMConversionPatterns, so attaching here for now.
  // TODO: We should have a single to_nvvm_type_converter.
  populateGpuMemorySpaceAttributeConversions(
      llvmTypeConverter, [](AddressSpace space) -> unsigned {
        switch (space) {
        case AddressSpace::Global:
          return static_cast<unsigned>(NVVM::NVVMMemorySpace::Global);
        case AddressSpace::Workgroup:
          return static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared);
        case AddressSpace::Private:
          return 0;
        }
        llvm_unreachable("unknown address space enum value");
        return static_cast<unsigned>(NVVM::NVVMMemorySpace::Generic);
      });
  // Used in GPUToNVVM/WmmaOpsToNvvm.cpp, so attaching here for now.
  // TODO: We should have a single to_nvvm_type_converter.
  llvmTypeConverter.addConversion(
      [&](MMAMatrixType type) -> Type { return convertMMAToLLVMType(type); });
  // Set a higher benefit so these patterns run before the generic LLVM
  // lowering.
  populateGpuToNVVMConversionPatterns(llvmTypeConverter, patterns,
                                      getBenefit());
}

LogicalResult
transform::ApplyGPUToNVVMConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}

void transform::ApplyGPUWwmaToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  populateGpuWMMAToNVVMConversionPatterns(llvmTypeConverter, patterns);
}

LogicalResult
transform::ApplyGPUWwmaToNVVMConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}

void transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp::
    populatePatterns(TypeConverter &typeConverter,
                     RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  populateGpuSubgroupReduceOpLoweringPattern(llvmTypeConverter, patterns);
}

LogicalResult transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp::
    verifyTypeConverter(transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}

void transform::ApplyGPUToROCDLConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  populateGpuMemorySpaceAttributeConversions(
      llvmTypeConverter, [](AddressSpace space) {
        switch (space) {
        case AddressSpace::Global:
          return ROCDL::ROCDLDialect::kGlobalMemoryAddressSpace;
        case AddressSpace::Workgroup:
          return ROCDL::ROCDLDialect::kSharedMemoryAddressSpace;
        case AddressSpace::Private:
          return ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace;
        }
        llvm_unreachable("unknown address space enum value");
      });
  FailureOr<amdgpu::Chipset> maybeChipset =
      amdgpu::Chipset::parse(getChipset());
  assert(llvm::succeeded(maybeChipset) && "expected valid chipset");
  populateGpuToROCDLConversionPatterns(
      llvmTypeConverter, patterns, mlir::gpu::amd::Runtime::HIP, *maybeChipset);
}

LogicalResult
transform::ApplyGPUToROCDLConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  FailureOr<amdgpu::Chipset> maybeChipset =
      amdgpu::Chipset::parse(getChipset());
  if (failed(maybeChipset)) {
    return emitOpError("Invalid chipset name: " + getChipset());
  }
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}

//===----------------------------------------------------------------------===//
// Apply...PatternsOp
//===----------------------------------------------------------------------===//

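// Unlike the conversion-pattern ops above, the ops below contribute greedy
// rewrite patterns that the surrounding transform.apply_patterns op applies to
// the payload IR.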
void ApplyGPURewritePatternsOp::populatePatterns(RewritePatternSet &patterns) {
  populateGpuRewritePatterns(patterns);
}

void transform::ApplyGPUPromoteShuffleToAMDGPUPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  std::optional<StringRef> chipsetName = getChipset();
  std::optional<amdgpu::Chipset> maybeChipset;
  if (chipsetName) {
    FailureOr<amdgpu::Chipset> parsedChipset =
        amdgpu::Chipset::parse(*chipsetName);
    assert(llvm::succeeded(parsedChipset) && "expected valid chipset");
    maybeChipset = parsedChipset;
  }

  populateGpuPromoteShuffleToAMDGPUPatterns(patterns, maybeChipset);
}

//===----------------------------------------------------------------------===//
// ApplyUnrollVectorsSubgroupMmaOp
//===----------------------------------------------------------------------===//

/// Pick an unrolling order that allows the tensor core operation to reuse its
/// LHS register.
static std::optional<SmallVector<int64_t>>
gpuMmaUnrollOrder(vector::ContractionOp contract) {
  SmallVector<int64_t> order;
  // First make the reduction dimensions the outermost dimensions.
  for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) {
    if (vector::isReductionIterator(iter)) {
      order.push_back(index);
    }
  }

  llvm::SmallDenseSet<int64_t> dims;
  for (AffineExpr expr : contract.getIndexingMapsArray()[0].getResults()) {
    dims.insert(cast<AffineDimExpr>(expr).getPosition());
  }
  // Then the parallel dimensions that are part of the LHS, as we want to reuse
  // the LHS.
  for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) {
    if (vector::isParallelIterator(iter) && dims.count(index)) {
      order.push_back(index);
    }
  }
  // Then the remaining parallel loops.
  for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) {
    if (vector::isParallelIterator(iter) && !dims.count(index)) {
      order.push_back(index);
    }
  }
  return order;
}
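// Example (illustrative): for a canonical matmul contraction with iterator
// types [parallel, parallel, reduction] and an LHS indexing map (m, k), the
// order computed above is [2, 0, 1]: the reduction dimension k first, then the
// LHS parallel dimension m, then the remaining parallel dimension n.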

/// Returns the target vector size for the target operation based on the native
/// vector size specified with `m`, `n`, and `k`.
static std::optional<SmallVector<int64_t>>
getSubgroupMmaNativeVectorSize(Operation *op, int64_t m, int64_t n, int64_t k) {
  if (auto contract = dyn_cast<vector::ContractionOp>(op)) {
    int64_t contractRank = contract.getIteratorTypes().size();
    if (contractRank < 3)
      return std::nullopt;
    SmallVector<int64_t> nativeSize(contractRank - 3, 1);
    nativeSize.append({m, n, k});
    return nativeSize;
  }
  if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
    int64_t writeRank = writeOp.getVectorType().getRank();
    if (writeRank < 2)
      return std::nullopt;
    SmallVector<int64_t> nativeSize(writeRank - 2, 1);
    nativeSize.append({m, n});
    return nativeSize;
  }
  if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
    // Transfer read ops may need different shapes based on how they are being
    // used. For simplicity just match the shape used by the extract strided op.
    VectorType sliceType;
    for (Operation *users : op->getUsers()) {
      auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
      if (!extract)
        return std::nullopt;
      auto vecType = cast<VectorType>(extract.getResult().getType());
      if (sliceType && sliceType != vecType)
        return std::nullopt;
      sliceType = vecType;
    }
    return llvm::to_vector(sliceType.getShape());
  }
  if ((OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1)) {
    if (auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0])) {
      // TODO: The condition for unrolling elementwise should be restricted
      // only to operations that need unrolling (connected to the contract).
      if (vecType.getRank() < 2)
        return std::nullopt;

      // First check whether there is a slice to infer the shape from. This is
      // required for cases where the accumulator type differs from the input
      // types, in which case we will see an `arith.ext_` between the contract
      // and transfer_read which needs to be unrolled.
      VectorType sliceType;
      for (Operation *users : op->getUsers()) {
        auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
        if (!extract)
          return std::nullopt;
        auto vecType = cast<VectorType>(extract.getResult().getType());
        if (sliceType && sliceType != vecType)
          return std::nullopt;
        sliceType = vecType;
      }
      if (sliceType)
        return llvm::to_vector(sliceType.getShape());

      // Else unroll for trailing elementwise.
      SmallVector<int64_t> nativeSize(vecType.getRank() - 2, 1);
      // Map elementwise ops to the output shape.
      nativeSize.append({m, n});
      return nativeSize;
    }
  }
  return std::nullopt;
}
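// Example (illustrative): with m = 16, n = 16 and k = 8, a 3-D vector.contract
// is assigned the native size [16, 16, 8] above, while a rank-2
// vector.transfer_write is assigned [16, 16].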

void transform::ApplyUnrollVectorsSubgroupMmaOp::populatePatterns(
    RewritePatternSet &patterns) {
  auto unrollOrder = [](Operation *op) -> std::optional<SmallVector<int64_t>> {
    auto contract = dyn_cast<vector::ContractionOp>(op);
    if (!contract)
      return std::nullopt;
    return gpuMmaUnrollOrder(contract);
  };

  int64_t m = getM();
  int64_t n = getN();
  int64_t k = getK();
  auto nativeShapeFn =
      [m, n, k](Operation *op) -> std::optional<SmallVector<int64_t>> {
    return getSubgroupMmaNativeVectorSize(op, m, n, k);
  };
  vector::populateVectorUnrollPatterns(
      patterns, vector::UnrollVectorOptions()
                    .setNativeShapeFn(nativeShapeFn)
                    .setUnrollTraversalOrderFn(unrollOrder));
}

//===----------------------------------------------------------------------===//
// EliminateBarriersOp
//===----------------------------------------------------------------------===//

void EliminateBarriersOp::populatePatterns(RewritePatternSet &patterns) {
  populateGpuEliminateBarriersPatterns(patterns);
}

//===----------------------------------------------------------------------===//
// Block and thread mapping utilities.
//===----------------------------------------------------------------------===//

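// The helpers below implement the common forall-to-GPU-id rewrite: verify the
// scf.forall mapping attributes, materialize hardware ids through a
// GpuIdBuilder, optionally predicate the body, and then inline the loop body
// in place of the forall.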
namespace {
/// Local types used for mapping verification.
struct MappingKind {};
struct BlockMappingKind : MappingKind {};
struct ThreadMappingKind : MappingKind {};
} // namespace

static DiagnosedSilenceableFailure
definiteFailureHelper(std::optional<TransformOpInterface> transformOp,
                      Operation *target, const Twine &message) {
  if (transformOp.has_value())
    return transformOp->emitDefiniteFailure() << message;
  return emitDefiniteFailure(target, message);
}

/// Check if the given mapping attributes are one of the desired attributes.
template <typename MappingKindType>
static DiagnosedSilenceableFailure
checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
                           scf::ForallOp forallOp) {
  if (!forallOp.getMapping().has_value()) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall op requires a mapping attribute");
  }

  bool hasBlockMapping = llvm::any_of(forallOp.getMapping().value(),
                                      llvm::IsaPred<GPUBlockMappingAttr>);
  bool hasWarpgroupMapping = llvm::any_of(
      forallOp.getMapping().value(), llvm::IsaPred<GPUWarpgroupMappingAttr>);
  bool hasWarpMapping = llvm::any_of(forallOp.getMapping().value(),
                                     llvm::IsaPred<GPUWarpMappingAttr>);
  bool hasThreadMapping = llvm::any_of(forallOp.getMapping().value(),
                                       llvm::IsaPred<GPUThreadMappingAttr>);
  bool hasLaneMapping = llvm::any_of(forallOp.getMapping().value(),
                                     llvm::IsaPred<GPULaneMappingAttr>);
  int64_t countMappingTypes = 0;
  countMappingTypes += hasBlockMapping ? 1 : 0;
  countMappingTypes += hasWarpgroupMapping ? 1 : 0;
  countMappingTypes += hasWarpMapping ? 1 : 0;
  countMappingTypes += hasThreadMapping ? 1 : 0;
  countMappingTypes += hasLaneMapping ? 1 : 0;
  if (countMappingTypes > 1) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "cannot mix different mapping types, use nesting");
  }
  if (std::is_same<MappingKindType, BlockMappingKind>::value &&
      !hasBlockMapping) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "scf.forall op requires a mapping attribute of kind 'block'");
  }
  if (std::is_same<MappingKindType, ThreadMappingKind>::value &&
      !hasLaneMapping && !hasThreadMapping && !hasWarpMapping &&
      !hasWarpgroupMapping) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall op requires a mapping attribute "
                                 "of kind 'thread' or 'warp'");
  }

  DenseSet<Attribute> seen;
  for (Attribute map : forallOp.getMapping()->getValue()) {
    if (seen.contains(map)) {
      return definiteFailureHelper(
          transformOp, forallOp,
          "duplicate attribute, cannot map different loops "
          "to the same mapping id");
    }
    seen.insert(map);
  }

  auto isLinear = [](DeviceMappingAttrInterface attr) {
    return attr.isLinearMapping();
  };
  if (llvm::any_of(forallOp.getDeviceMappingAttrs(), isLinear) &&
      !llvm::all_of(forallOp.getDeviceMappingAttrs(), isLinear)) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "cannot mix linear and non-linear mapping modes");
  }

  FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr =
      forallOp.getDeviceMaskingAttr();
  if (succeeded(maybeMaskingAttr) && *maybeMaskingAttr &&
      !forallOp.usesLinearMapping()) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "device masking is only available in linear mapping mode");
  }

  return DiagnosedSilenceableFailure::success();
}

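// For reference, an scf.forall accepted by these mapping checks looks like
// (illustrative IR):
//   scf.forall (%i, %j) in (8, 16) {
//     ...
//   } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}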
template <typename MappingKindType>
static DiagnosedSilenceableFailure
verifyGpuMapping(std::optional<TransformOpInterface> transformOp,
                 scf::ForallOp forallOp) {
  // Check that the types of the mapping attributes match.
  DiagnosedSilenceableFailure typeRes =
      checkMappingAttributeTypes<MappingKindType>(transformOp, forallOp);
  if (!typeRes.succeeded())
    return typeRes;

  // Perform other non-type verifications.
  if (!forallOp.isNormalized())
    return definiteFailureHelper(transformOp, forallOp,
                                 "unsupported non-normalized loops");
  if (forallOp.getNumResults() > 0)
    return definiteFailureHelper(transformOp, forallOp,
                                 "only bufferized scf.forall can be mapped");
  bool useLinearMapping = forallOp.usesLinearMapping();
  // TODO: This would be more natural with support for Optional<EnumParameter>
  // in GPUDeviceMappingAttr.
  int64_t maxNumMappingsSupported =
      useLinearMapping ? (getMaxEnumValForMappingId() -
                          static_cast<uint64_t>(MappingId::DimZ))
                       : 3;
  if (forallOp.getRank() > maxNumMappingsSupported) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall with rank > ")
           << maxNumMappingsSupported
           << " does not lower for the specified mapping attribute type";
  }
  auto numParallelIterations =
      getConstantIntValues(forallOp.getMixedUpperBound());
  if (!forallOp.isNormalized() || !numParallelIterations.has_value()) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "requires statically sized, normalized forall op");
  }
  return DiagnosedSilenceableFailure::success();
}

/// Struct to return the result of the rewrite of a forall operation.
struct ForallRewriteResult {
  SmallVector<int64_t> mappingSizes;
  SmallVector<Value> mappingIds;
};

/// Helper to replace ids of dimensions known to be 1 by 0 to simplify the IR.
template <typename OpTy, typename OperationOrBlock>
static void
replaceUnitMappingIdsHelper(RewriterBase &rewriter, Location loc,
                            OperationOrBlock *parent, Value replacement,
                            ArrayRef<int64_t> availableMappingSizes) {
  parent->walk([&](OpTy idOp) {
    if (availableMappingSizes[static_cast<int64_t>(idOp.getDimension())] == 1)
      rewriter.replaceAllUsesWith(idOp.getResult(), replacement);
  });
}

static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    scf::ForallOp forallOp, ArrayRef<int64_t> availableMappingSizes,
    ForallRewriteResult &result, const GpuIdBuilder &gpuIdBuilder) {
  LDBG() << "--start rewriteOneForallCommonImpl";

  // Step 1. Complete the mapping to a full mapping (with 1s) if necessary.
  auto numParallelIterations =
      getConstantIntValues(forallOp.getMixedUpperBound());
  assert(forallOp.isNormalized() && numParallelIterations.has_value() &&
         "requires statically sized, normalized forall op");
  SmallVector<int64_t> tmpMappingSizes = numParallelIterations.value();
  SmallVector<DeviceMappingAttrInterface> forallMappingAttrsVec =
      forallOp.getDeviceMappingAttrs();
  SetVector<Attribute> forallMappingAttrs;
  forallMappingAttrs.insert_range(forallMappingAttrsVec);
  auto comparator = [](Attribute a, Attribute b) -> bool {
    return cast<DeviceMappingAttrInterface>(a).getMappingId() <
           cast<DeviceMappingAttrInterface>(b).getMappingId();
  };

  // Step 1.b. In the linear case, compute the max mapping to avoid needlessly
  // mapping all dimensions. In the 3-D mapping case we need to map all
  // dimensions.
  DeviceMappingAttrInterface maxMapping = cast<DeviceMappingAttrInterface>(
      *llvm::max_element(forallMappingAttrs, comparator));
  DeviceMappingAttrInterface maxLinearMapping;
  if (maxMapping.isLinearMapping())
    maxLinearMapping = maxMapping;
  for (auto attr : gpuIdBuilder.mappingAttributes) {
    // If attr overflows, just skip.
    if (maxLinearMapping && comparator(maxLinearMapping, attr))
      continue;
    // Try to insert. If the element was already present, just continue.
    if (!forallMappingAttrs.insert(attr))
      continue;
    // Otherwise, we have a new insertion without a size -> use size 1.
    tmpMappingSizes.push_back(1);
  }
  LDBG() << "----tmpMappingSizes extracted from scf.forall op: "
         << llvm::interleaved(tmpMappingSizes);

  // Step 2. Sort the values by the corresponding DeviceMappingAttrInterface.
  SmallVector<int64_t> forallMappingSizes = getValuesSortedByKey(
      forallMappingAttrs.getArrayRef(), tmpMappingSizes, comparator);
  LDBG() << "----forallMappingSizes: " << llvm::interleaved(forallMappingSizes);
  LDBG() << "----forallMappingAttrs: " << llvm::interleaved(forallMappingAttrs);

  // Step 3. Generate the mappingIdOps using the provided generator.
  Location loc = forallOp.getLoc();
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(forallOp);
  SmallVector<int64_t> originalBasis(availableMappingSizes);
  bool originalBasisWasProvided = !originalBasis.empty();
  if (!originalBasisWasProvided) {
    LDBG() << "----originalBasis was not provided, deriving it and there will "
              "be no predication";
    originalBasis = forallMappingSizes;
    while (originalBasis.size() < 3)
      originalBasis.push_back(1);
  } else {
    LDBG() << "----originalBasis was provided, using it, there will be "
              "predication";
  }
  LDBG() << "------originalBasis: " << llvm::interleaved(originalBasis);

  IdBuilderResult builderResult =
      gpuIdBuilder.idBuilder(rewriter, loc, forallMappingSizes, originalBasis);
  if (!builderResult.errorMsg.empty())
    return definiteFailureHelper(transformOp, forallOp, builderResult.errorMsg);

  LDBG() << builderResult;

  // Step 4. Map the induction variables to the mappingIdOps, this may involve
  // a permutation.
  SmallVector<Value> mappingIdOps = builderResult.mappingIdOps;
  IRMapping bvm;
  for (auto [iv, dim] : llvm::zip_equal(
           forallOp.getInductionVars(),
           forallMappingAttrs.getArrayRef().take_front(forallOp.getRank()))) {
    auto mappingAttr = cast<DeviceMappingAttrInterface>(dim);
    Value peIdOp = mappingIdOps[mappingAttr.getRelativeIndex()];
    LDBG() << "----map: " << iv << " to " << peIdOp;
    bvm.map(iv, peIdOp);
  }

  // Step 5. If the originalBasis is already known, create conditionals to
  // predicate the region. Otherwise, the current forall determines the
  // originalBasis and no predication occurs.
  Value predicate;
  if (originalBasisWasProvided) {
    for (Value tmpPredicate : builderResult.predicateOps) {
      predicate = predicate ? arith::AndIOp::create(rewriter, loc, predicate,
                                                    tmpPredicate)
                            : tmpPredicate;
    }
  }

  // Step 6. Move the body of forallOp.
  // Erase the terminator first, it will not be used.
  rewriter.eraseOp(forallOp.getTerminator());
  Block *targetBlock;
  Block::iterator insertionPoint;
  if (predicate) {
    // Step 6.a. If predicated, move at the beginning.
    auto ifOp = scf::IfOp::create(rewriter, loc, predicate,
                                  /*withElseRegion=*/false);
    targetBlock = ifOp.thenBlock();
    insertionPoint = ifOp.thenBlock()->begin();
  } else {
    // Step 6.b. Otherwise, move inline just at the rewriter insertion point.
    targetBlock = forallOp->getBlock();
    insertionPoint = rewriter.getInsertionPoint();
  }
  Block &sourceBlock = forallOp.getRegion().front();
  targetBlock->getOperations().splice(insertionPoint,
                                      sourceBlock.getOperations());

  // Step 7. RAUW indices.
  for (Value loopIndex : forallOp.getInductionVars()) {
    Value threadIdx = bvm.lookup(loopIndex);
    rewriter.replaceAllUsesWith(loopIndex, threadIdx);
  }

  // Step 8. Erase the old op.
  rewriter.eraseOp(forallOp);

  LDBG() << "----result forallMappingSizes: "
         << llvm::interleaved(forallMappingSizes);
  LDBG() << "----result mappingIdOps: " << llvm::interleaved(mappingIdOps);

  result = ForallRewriteResult{forallMappingSizes, mappingIdOps};
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// MapForallToBlocks
//===----------------------------------------------------------------------===//

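// From the transform dialect this rewrite is typically requested as
// (illustrative syntax):
//   %launch = transform.gpu.map_forall_to_blocks %target
//       generate_gpu_launch grid_dims = [16, 8, 1]
//       : (!transform.any_op) -> !transform.any_op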
DiagnosedSilenceableFailure mlir::transform::gpu::mapForallToBlocksImpl(
    RewriterBase &rewriter, TransformOpInterface transformOp,
    scf::ForallOp forallOp, SmallVectorImpl<int64_t> &gridDims,
    const GpuIdBuilder &gpuIdBuilder) {
  LDBG() << "Start mapForallToBlocksImpl";

  {
    // GPU-specific verifications. There is no better place to anchor
    // those right now: the ForallOp is target-independent and the transform
    // op does not apply to individual ForallOp.
    DiagnosedSilenceableFailure diag =
        verifyGpuMapping<BlockMappingKind>(transformOp, forallOp);
    if (!diag.succeeded())
      return diag;
  }

  Location loc = forallOp.getLoc();
  Block *parentBlock = forallOp->getBlock();
  Value zero;
  {
    // Create an early zero index value for replacements and immediately reset
    // the insertion point.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(parentBlock);
    zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  }

  ForallRewriteResult rewriteResult;
  DiagnosedSilenceableFailure diag = rewriteOneForallCommonImpl(
      rewriter, transformOp, forallOp,
      /*availableMappingSizes=*/gridDims, rewriteResult, gpuIdBuilder);

  // Return if anything goes wrong, use silenceable failure as a match
  // failure.
  if (!diag.succeeded())
    return diag;

  // If gridDims was not provided already, set it from the return.
  if (gridDims.empty()) {
    gridDims = rewriteResult.mappingSizes;
    while (gridDims.size() < 3)
      gridDims.push_back(1);
  }
  assert(gridDims.size() == 3 && "Need 3-D gridDims");

  // Replace ids of dimensions known to be 1 by 0 to simplify the IR.
  // Here, the result of mapping determines the available mapping sizes.
  replaceUnitMappingIdsHelper<BlockDimOp>(rewriter, loc, parentBlock, zero,
                                          rewriteResult.mappingSizes);

  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure
mlir::transform::gpu::findTopLevelForallOp(Operation *target,
                                           scf::ForallOp &topLevelForallOp,
                                           TransformOpInterface transformOp) {
  auto walkResult = target->walk([&](scf::ForallOp forallOp) {
    if (forallOp->getParentOfType<scf::ForallOp>())
      return WalkResult::advance();
    if (topLevelForallOp)
      // TODO: Handle multiple foralls if they are independent.
      return WalkResult::interrupt();
    topLevelForallOp = forallOp;
    return WalkResult::advance();
  });

  if (walkResult.wasInterrupted() || !topLevelForallOp)
    return transformOp.emitSilenceableError()
           << "could not find a unique topLevel scf.forall";
  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure transform::MapForallToBlocks::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, TransformState &state) {
  LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
  auto transformOp = cast<TransformOpInterface>(getOperation());

  if (!getGenerateGpuLaunch() && !gpuLaunch) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "Given target is not gpu.launch, set `generate_gpu_launch` "
           "attribute";
    diag.attachNote(target->getLoc()) << "when applied to this payload op";
    return diag;
  }

  scf::ForallOp topLevelForallOp;
  DiagnosedSilenceableFailure diag = findTopLevelForallOp(
      target, topLevelForallOp, transformOp);
  if (!diag.succeeded()) {
    diag.attachNote(target->getLoc()) << "when applied to this payload op";
    return diag;
  }
  assert(topLevelForallOp && "expect an scf.forall");

  SmallVector<int64_t> gridDims{getGridDims()};
  if (!getGenerateGpuLaunch() && gridDims.size() != 3)
    return transformOp.emitDefiniteFailure("transform requires size-3 mapping");

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(topLevelForallOp);

  // Generate the gpu launch here and move the forall inside.
  if (getGenerateGpuLaunch()) {
    DiagnosedSilenceableFailure diag =
        createGpuLaunch(rewriter, target->getLoc(), transformOp, gpuLaunch);
    if (!diag.succeeded())
      return diag;

    rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front());
    Operation *newForallOp = rewriter.clone(*topLevelForallOp);
    rewriter.eraseOp(topLevelForallOp);
    topLevelForallOp = cast<scf::ForallOp>(newForallOp);
  }

  // The BlockIdBuilder adapts to whatever is thrown at it.
  bool useLinearMapping = false;
  if (topLevelForallOp.getMapping())
    useLinearMapping = topLevelForallOp.usesLinearMapping();

  FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr =
      topLevelForallOp.getDeviceMaskingAttr();
  assert(succeeded(maybeMaskingAttr) && "unexpected failed maybeMaskingAttr");
  assert((!*maybeMaskingAttr || useLinearMapping) &&
         "masking requires linear mapping");

  GpuBlockIdBuilder gpuBlockIdBuilder(getContext(), useLinearMapping,
                                      *maybeMaskingAttr);

  diag = mapForallToBlocksImpl(
      rewriter, transformOp, topLevelForallOp, gridDims, gpuBlockIdBuilder);
  if (!diag.succeeded())
    return diag;

  // Set the GPU launch configuration for the grid dims late, this is
  // subject to IR inspection.
  diag = alterGpuLaunch(rewriter, gpuLaunch,
                        cast<TransformOpInterface>(getOperation()), gridDims[0],
                        gridDims[1], gridDims[2]);

  results.push_back(gpuLaunch);
  return diag;
}

LogicalResult transform::MapForallToBlocks::verify() {
  if (!getGridDims().empty() && getGridDims().size() != 3) {
    return emitOpError() << "transform requires empty or size-3 grid_dims";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// MapNestedForallToThreads
//===----------------------------------------------------------------------===//

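// From the transform dialect this rewrite is typically requested as
// (illustrative syntax):
//   transform.gpu.map_nested_forall_to_threads %gpu_launch
//       block_dims = [32, 4, 1] : (!transform.any_op) -> !transform.any_op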
static DiagnosedSilenceableFailure checkMappingSpec(
    std::optional<TransformOpInterface> transformOp, scf::ForallOp forallOp,
    ArrayRef<int64_t> numParallelIterations, ArrayRef<int64_t> blockOrGridSizes,
    int factor, bool useLinearMapping = false) {
  if (!useLinearMapping && blockOrGridSizes.front() % factor != 0) {
    auto diag = definiteFailureHelper(
        transformOp, forallOp,
        Twine("3-D mapping: size of threadIdx.x must be a multiple of ") +
            Twine(factor));
    return diag;
  }
  if (computeProduct(numParallelIterations) * factor >
      computeProduct(blockOrGridSizes)) {
    auto diag = definiteFailureHelper(
        transformOp, forallOp,
        Twine("the number of required parallel resources (blocks or "
              "threads) ") +
            Twine(computeProduct(numParallelIterations) * factor) +
            " overflows the number of available resources " +
            Twine(computeProduct(blockOrGridSizes)));
    return diag;
  }
  return DiagnosedSilenceableFailure::success();
}
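// For example (illustrative numbers): with numParallelIterations = [4, 8],
// blockOrGridSizes = [64, 2, 1] and factor = 32 (warp mapping), the first
// check above passes (64 % 32 == 0) but the second fails because
// 4 * 8 * 32 = 1024 exceeds the 64 * 2 * 1 = 128 available threads.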

static DiagnosedSilenceableFailure
getThreadIdBuilder(std::optional<TransformOpInterface> transformOp,
                   scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes,
                   int64_t warpSize, GpuIdBuilder &gpuIdBuilder) {
  DeviceMappingAttrInterface mappingAttr =
      forallOp.getDeviceMappingAttrs().front();
  bool useLinearMapping = mappingAttr.isLinearMapping();

  // Sanity checks that may result in runtime verification errors.
  auto numParallelIterations =
      getConstantIntValues((forallOp.getMixedUpperBound()));
  if (!forallOp.isNormalized() || !numParallelIterations.has_value()) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "requires statically sized, normalized forall op");
  }
  int64_t factor = 1;
  if (isa<GPUWarpgroupMappingAttr>(mappingAttr)) {
    factor = GpuWarpgroupIdBuilder::kNumWarpsPerGroup * warpSize;
  } else if (isa<GPUWarpMappingAttr>(mappingAttr)) {
    factor = warpSize;
  }
  DiagnosedSilenceableFailure diag =
      checkMappingSpec(transformOp, forallOp, numParallelIterations.value(),
                       blockSizes, factor, useLinearMapping);
  if (!diag.succeeded())
    return diag;

  FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr =
      forallOp.getDeviceMaskingAttr();
  assert(succeeded(maybeMaskingAttr) && "unexpected failed maybeMaskingAttr");
  assert((!*maybeMaskingAttr || useLinearMapping) &&
         "masking requires linear mapping");

  // Start mapping.
  MLIRContext *ctx = forallOp.getContext();
  gpuIdBuilder =
      llvm::TypeSwitch<DeviceMappingAttrInterface, GpuIdBuilder>(mappingAttr)
          .Case([&](GPUWarpgroupMappingAttr) {
            return GpuWarpgroupIdBuilder(ctx, warpSize, useLinearMapping,
                                         *maybeMaskingAttr);
          })
          .Case([&](GPUWarpMappingAttr) {
            return GpuWarpIdBuilder(ctx, warpSize, useLinearMapping,
                                    *maybeMaskingAttr);
          })
          .Case([&](GPUThreadMappingAttr) {
            return GpuThreadIdBuilder(ctx, useLinearMapping, *maybeMaskingAttr);
          })
          .Case([&](GPULaneMappingAttr) {
            return GpuLaneIdBuilder(ctx, warpSize, useLinearMapping,
                                    *maybeMaskingAttr);
          })
          .Default([&](DeviceMappingAttrInterface) -> GpuIdBuilder {
            llvm_unreachable("unknown mapping attribute");
          });
  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure mlir::transform::gpu::mapOneForallToThreadsImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes, int64_t warpSize,
    bool syncAfterDistribute) {

  {
    // GPU-specific verifications. There is no better place to anchor
    // those right now: the ForallOp is target-independent and the transform
    // op does not apply to individual ForallOp.
    DiagnosedSilenceableFailure diag =
        verifyGpuMapping<ThreadMappingKind>(transformOp, forallOp);
    if (!diag.succeeded())
      return diag;
  }

  GpuIdBuilder gpuIdBuilder;
  {
    // Try to construct the id builder, if it fails, return.
    DiagnosedSilenceableFailure diag = getThreadIdBuilder(
        transformOp, forallOp, blockSizes, warpSize, gpuIdBuilder);
    if (!diag.succeeded())
      return diag;
  }

  Location loc = forallOp.getLoc();
  OpBuilder::InsertionGuard g(rewriter);
  // Insert after to allow for syncthreads after `forall` is erased.
  rewriter.setInsertionPointAfter(forallOp);
  ForallRewriteResult rewriteResult;
  DiagnosedSilenceableFailure diag = rewriteOneForallCommonImpl(
      rewriter, transformOp, forallOp, blockSizes, rewriteResult, gpuIdBuilder);
  if (!diag.succeeded())
    return diag;
  // Add a syncthreads if needed. TODO: warpsync
  if (syncAfterDistribute)
    BarrierOp::create(rewriter, loc);

  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForallToThreadsImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    Operation *target, ArrayRef<int64_t> blockDims, int64_t warpSize,
    bool syncAfterDistribute) {
  LDBG() << "Start mapNestedForallToThreadsImpl";
  if (blockDims.size() != 3) {
    return definiteFailureHelper(transformOp, target,
                                 "requires size-3 thread mapping");
  }

  // Create an early zero index value for replacements.
  Location loc = target->getLoc();
  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
  WalkResult walkResult = target->walk([&](scf::ForallOp forallOp) {
    diag = mapOneForallToThreadsImpl(
        rewriter, transformOp, forallOp, blockDims, warpSize,
        syncAfterDistribute);
    if (diag.isDefiniteFailure())
      return WalkResult::interrupt();
    if (diag.succeeded())
      return WalkResult::skip();
    return WalkResult::advance();
  });
  if (walkResult.wasInterrupted())
    return diag;

  // Replace ids of dimensions known to be 1 by 0 to simplify the IR.
  // Here, the result of mapping determines the available mapping sizes.
  replaceUnitMappingIdsHelper<ThreadIdOp>(rewriter, loc, target, zero,
                                          blockDims);

  return DiagnosedSilenceableFailure::success();
}
930 
931 DiagnosedSilenceableFailure transform::MapNestedForallToThreads::applyToOne(
932  transform::TransformRewriter &rewriter, Operation *target,
933  ApplyToEachResultList &results, TransformState &state) {
934  LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
935  auto transformOp = cast<TransformOpInterface>(getOperation());
936 
937  // Basic high-level verifications.
938  if (!gpuLaunch)
939  return emitSilenceableError() << "Given target is not a gpu.launch";
940 
941  // Mapping to block ids.
942  SmallVector<int64_t> blockDims{getBlockDims()};
944  checkGpuLimits(transformOp, std::nullopt, std::nullopt, std::nullopt,
945  blockDims[0], blockDims[1], blockDims[2]);
946  if (diag.isSilenceableFailure()) {
947  diag.attachNote(getLoc()) << getBlockDimsAttrName() << " is too large";
948  return diag;
949  }
950 
951  // Set the GPU launch configuration for the block dims early, this is not
952  // subject to IR inspection.
953  diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, std::nullopt,
954  std::nullopt, std::nullopt, blockDims[0], blockDims[1],
955  blockDims[2]);
956 
957  rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front());
958  diag =
959  mapNestedForallToThreadsImpl(rewriter, transformOp, gpuLaunch, blockDims,
960  getWarpSize(), getSyncAfterDistribute());
961 
962  results.push_back(gpuLaunch.getOperation());
963  return diag;
964 }

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//

namespace {
/// Registers the GPU transform ops and declares the dialects that applying
/// them may generate.
class GPUTransformDialectExtension
    : public transform::TransformDialectExtension<
          GPUTransformDialectExtension> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GPUTransformDialectExtension)

  GPUTransformDialectExtension() {
    declareGeneratedDialect<GPUDialect>();
    declareGeneratedDialect<amdgpu::AMDGPUDialect>();
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<scf::SCFDialect>();
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"
        >();
  }
};
} // namespace

#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"

void mlir::gpu::registerTransformDialectExtension(DialectRegistry &registry) {
  registry.addExtensions<GPUTransformDialectExtension>();
}
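
// Typical client usage (illustrative): register the extension on a
// DialectRegistry before creating the MLIRContext.
//   mlir::DialectRegistry registry;
//   mlir::gpu::registerTransformDialectExtension(registry);
//   mlir::MLIRContext context(registry);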