//===- GPUTransformOps.cpp - Implementation of GPU transform ops ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"

#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/TransformOps/Utils.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/LogicalResult.h"
#include <optional>
#include <type_traits>
using namespace mlir;
using namespace mlir::gpu;
using namespace mlir::transform;
using namespace mlir::transform::gpu;

#define DEBUG_TYPE "gpu-transforms"

//===----------------------------------------------------------------------===//
// Apply...ConversionPatternsOp
//===----------------------------------------------------------------------===//

void transform::ApplyGPUToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  // NVVM uses alloca in the default address space to represent private
  // memory allocations, so drop private annotations. NVVM uses address
  // space 3 for shared memory. NVVM uses the default address space to
  // represent global memory.
  // Used in populateGpuToNVVMConversionPatterns, so attaching here for now.
  // TODO: We should have a single to_nvvm_type_converter.
  populateGpuMemorySpaceAttributeConversions(
      llvmTypeConverter, [](AddressSpace space) -> unsigned {
        switch (space) {
        case AddressSpace::Global:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kGlobalMemorySpace);
        case AddressSpace::Workgroup:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kSharedMemorySpace);
        case AddressSpace::Private:
          return 0;
        }
        llvm_unreachable("unknown address space enum value");
        return 0;
      });
  // Used in GPUToNVVM/WmmaOpsToNvvm.cpp, so attaching here for now.
  // TODO: We should have a single to_nvvm_type_converter.
  llvmTypeConverter.addConversion(
      [&](MMAMatrixType type) -> Type { return convertMMAToLLVMType(type); });
  // Set higher benefit, so patterns will run before generic LLVM lowering.
  populateGpuToNVVMConversionPatterns(llvmTypeConverter, patterns,
                                      getBenefit());
}

LogicalResult
transform::ApplyGPUToNVVMConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}

void transform::ApplyGPUWwmaToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  populateGpuWMMAToNVVMConversionPatterns(llvmTypeConverter, patterns);
}

LogicalResult
transform::ApplyGPUWwmaToNVVMConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}

void transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp::
    populatePatterns(TypeConverter &typeConverter,
                     RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  populateGpuSubgroupReduceOpLoweringPattern(llvmTypeConverter, patterns);
}

LogicalResult transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp::
    verifyTypeConverter(transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}

void transform::ApplyGPUToROCDLConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  populateGpuMemorySpaceAttributeConversions(
      llvmTypeConverter, [](AddressSpace space) {
        switch (space) {
        case AddressSpace::Global:
          return ROCDL::ROCDLDialect::kGlobalMemoryAddressSpace;
        case AddressSpace::Workgroup:
          return ROCDL::ROCDLDialect::kSharedMemoryAddressSpace;
        case AddressSpace::Private:
          return ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace;
        }
        llvm_unreachable("unknown address space enum value");
      });
  FailureOr<amdgpu::Chipset> maybeChipset =
      amdgpu::Chipset::parse(getChipset());
  assert(llvm::succeeded(maybeChipset) && "expected valid chipset");
  populateGpuToROCDLConversionPatterns(
      llvmTypeConverter, patterns, mlir::gpu::amd::Runtime::HIP, *maybeChipset);
}

LogicalResult
transform::ApplyGPUToROCDLConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  FailureOr<amdgpu::Chipset> maybeChipset =
      amdgpu::Chipset::parse(getChipset());
  if (failed(maybeChipset)) {
    return emitOpError("Invalid chipset name: " + getChipset());
  }
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}

//===----------------------------------------------------------------------===//
// Apply...PatternsOp
//===----------------------------------------------------------------------===//

void ApplyGPURewritePatternsOp::populatePatterns(RewritePatternSet &patterns) {
  populateGpuRewritePatterns(patterns);
}

void transform::ApplyGPUPromoteShuffleToAMDGPUPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  std::optional<StringRef> chipsetName = getChipset();
  std::optional<amdgpu::Chipset> maybeChipset;
  if (chipsetName) {
    FailureOr<amdgpu::Chipset> parsedChipset =
        amdgpu::Chipset::parse(*chipsetName);
    assert(llvm::succeeded(parsedChipset) && "expected valid chipset");
    maybeChipset = parsedChipset;
  }

  populateGpuPromoteShuffleToAMDGPUPatterns(patterns, maybeChipset);
}


//===----------------------------------------------------------------------===//
// ApplyUnrollVectorsSubgroupMmaOp
//===----------------------------------------------------------------------===//

/// Pick an unrolling order that will allow tensorcore operation to reuse LHS
/// register.
static std::optional<SmallVector<int64_t>>
gpuMmaUnrollOrder(vector::ContractionOp contract) {
  SmallVector<int64_t> order;
  // First make reduction the outer dimensions.
  for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) {
    if (vector::isReductionIterator(iter)) {
      order.push_back(index);
    }
  }

  llvm::SmallDenseSet<int64_t> dims;
  for (AffineExpr expr : contract.getIndexingMapsArray()[0].getResults()) {
    dims.insert(cast<AffineDimExpr>(expr).getPosition());
  }
  // Then parallel dimensions that are part of Lhs as we want to re-use Lhs.
  for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) {
    if (vector::isParallelIterator(iter) && dims.count(index)) {
      order.push_back(index);
    }
  }
  // Then the remaining parallel loops.
  for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) {
    if (vector::isParallelIterator(iter) && !dims.count(index)) {
      order.push_back(index);
    }
  }
  return order;
}

/// Returns the target vector size for the target operation based on the native
/// vector size specified with `m`, `n`, and `k`.
static std::optional<SmallVector<int64_t>>
getSubgroupMmaNativeVectorSize(Operation *op, int64_t m, int64_t n, int64_t k) {
  if (auto contract = dyn_cast<vector::ContractionOp>(op)) {
    int64_t contractRank = contract.getIteratorTypes().size();
    if (contractRank < 3)
      return std::nullopt;
    SmallVector<int64_t> nativeSize(contractRank - 3, 1);
    nativeSize.append({m, n, k});
    return nativeSize;
  }
  if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
    int64_t writeRank = writeOp.getVectorType().getRank();
    if (writeRank < 2)
      return std::nullopt;
    SmallVector<int64_t> nativeSize(writeRank - 2, 1);
    nativeSize.append({m, n});
    return nativeSize;
  }
  if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
    // Transfer read ops may need different shapes based on how they are being
    // used. For simplicity just match the shape used by the extract strided op.
    VectorType sliceType;
    for (Operation *users : op->getUsers()) {
      auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
      if (!extract)
        return std::nullopt;
      auto vecType = cast<VectorType>(extract.getResult().getType());
      if (sliceType && sliceType != vecType)
        return std::nullopt;
      sliceType = vecType;
    }
    return llvm::to_vector(sliceType.getShape());
  }
  if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
    if (auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0])) {
      // TODO: The condition for unrolling elementwise should be restricted
      // only to operations that need unrolling (connected to the contract).
      if (vecType.getRank() < 2)
        return std::nullopt;

      // First check whether there is a slice to infer the shape from. This is
      // required for cases where the accumulator type differs from the input
      // types, in which case we will see an `arith.ext_` between the contract
      // and transfer_read which needs to be unrolled.
      VectorType sliceType;
      for (Operation *users : op->getUsers()) {
        auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
        if (!extract)
          return std::nullopt;
        auto vecType = cast<VectorType>(extract.getResult().getType());
        if (sliceType && sliceType != vecType)
          return std::nullopt;
        sliceType = vecType;
      }
      if (sliceType)
        return llvm::to_vector(sliceType.getShape());

      // Else unroll for trailing elementwise.
      SmallVector<int64_t> nativeSize(vecType.getRank() - 2, 1);
      // Map elementwise ops to the output shape.
      nativeSize.append({m, n});
      return nativeSize;
    }
  }
  return std::nullopt;
}

void transform::ApplyUnrollVectorsSubgroupMmaOp::populatePatterns(
    RewritePatternSet &patterns) {
  auto unrollOrder = [](Operation *op) -> std::optional<SmallVector<int64_t>> {
    auto contract = dyn_cast<vector::ContractionOp>(op);
    if (!contract)
      return std::nullopt;
    return gpuMmaUnrollOrder(contract);
  };

  int64_t m = getM();
  int64_t n = getN();
  int64_t k = getK();
  auto nativeShapeFn =
      [m, n, k](Operation *op) -> std::optional<SmallVector<int64_t>> {
    return getSubgroupMmaNativeVectorSize(op, m, n, k);
  };
  vector::populateVectorUnrollPatterns(
      patterns, vector::UnrollVectorOptions()
                    .setNativeShapeFn(nativeShapeFn)
                    .setUnrollTraversalOrderFn(unrollOrder));
}

//===----------------------------------------------------------------------===//
// EliminateBarriersOp
//===----------------------------------------------------------------------===//

void EliminateBarriersOp::populatePatterns(RewritePatternSet &patterns) {
  populateGpuEliminateBarriersPatterns(patterns);
}

//===----------------------------------------------------------------------===//
// Block and thread mapping utilities.
//===----------------------------------------------------------------------===//

namespace {
/// Local types used for mapping verification.
struct MappingKind {};
struct BlockMappingKind : MappingKind {};
struct ThreadMappingKind : MappingKind {};
} // namespace

static DiagnosedSilenceableFailure
definiteFailureHelper(std::optional<TransformOpInterface> transformOp,
                      Operation *target, const Twine &message) {
  if (transformOp.has_value())
    return transformOp->emitDefiniteFailure() << message;
  return emitDefiniteFailure(target, message);
}

/// Check if the given mapping attributes are one of the desired attributes.
template <typename MappingKindType>
static DiagnosedSilenceableFailure
checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
                           scf::ForallOp forallOp) {
  if (!forallOp.getMapping().has_value()) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall op requires a mapping attribute");
  }

  bool hasBlockMapping = llvm::any_of(forallOp.getMapping().value(),
                                      llvm::IsaPred<GPUBlockMappingAttr>);
  bool hasWarpgroupMapping = llvm::any_of(
      forallOp.getMapping().value(), llvm::IsaPred<GPUWarpgroupMappingAttr>);
  bool hasWarpMapping = llvm::any_of(forallOp.getMapping().value(),
                                     llvm::IsaPred<GPUWarpMappingAttr>);
  bool hasThreadMapping = llvm::any_of(forallOp.getMapping().value(),
                                       llvm::IsaPred<GPUThreadMappingAttr>);
  bool hasLaneMapping = llvm::any_of(forallOp.getMapping().value(),
                                     llvm::IsaPred<GPULaneMappingAttr>);
  int64_t countMappingTypes = 0;
  countMappingTypes += hasBlockMapping ? 1 : 0;
  countMappingTypes += hasWarpgroupMapping ? 1 : 0;
  countMappingTypes += hasWarpMapping ? 1 : 0;
  countMappingTypes += hasThreadMapping ? 1 : 0;
  countMappingTypes += hasLaneMapping ? 1 : 0;
  if (countMappingTypes > 1) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "cannot mix different mapping types, use nesting");
  }
  if (std::is_same<MappingKindType, BlockMappingKind>::value &&
      !hasBlockMapping) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "scf.forall op requires a mapping attribute of kind 'block'");
  }
  if (std::is_same<MappingKindType, ThreadMappingKind>::value &&
      !hasLaneMapping && !hasThreadMapping && !hasWarpMapping &&
      !hasWarpgroupMapping) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall op requires a mapping attribute "
                                 "of kind 'thread' or 'warp'");
  }

  DenseSet<Attribute> seen;
  for (Attribute map : forallOp.getMapping()->getValue()) {
    if (seen.contains(map)) {
      return definiteFailureHelper(
          transformOp, forallOp,
          "duplicate attribute, cannot map different loops "
          "to the same mapping id");
    }
    seen.insert(map);
  }

  auto isLinear = [](DeviceMappingAttrInterface attr) {
    return attr.isLinearMapping();
  };
  if (llvm::any_of(forallOp.getDeviceMappingAttrs(), isLinear) &&
      !llvm::all_of(forallOp.getDeviceMappingAttrs(), isLinear)) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "cannot mix linear and non-linear mapping modes");
  }

  FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr =
      forallOp.getDeviceMaskingAttr();
  if (succeeded(maybeMaskingAttr) && *maybeMaskingAttr &&
      !forallOp.usesLinearMapping()) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "device masking is only available in linear mapping mode");
  }

  return DiagnosedSilenceableFailure::success();
}

template <typename MappingKindType>
static DiagnosedSilenceableFailure
verifyGpuMapping(std::optional<TransformOpInterface> transformOp,
                 scf::ForallOp forallOp) {
  // Check the types of the mapping attributes match.
  DiagnosedSilenceableFailure typeRes =
      checkMappingAttributeTypes<MappingKindType>(transformOp, forallOp);
  if (!typeRes.succeeded())
    return typeRes;

  // Perform other non-types verifications.
  if (!forallOp.isNormalized())
    return definiteFailureHelper(transformOp, forallOp,
                                 "unsupported non-normalized loops");
  if (forallOp.getNumResults() > 0)
    return definiteFailureHelper(transformOp, forallOp,
                                 "only bufferized scf.forall can be mapped");
  bool useLinearMapping = forallOp.usesLinearMapping();
  // TODO: This would be more natural with support for Optional<EnumParameter>
  // in GPUDeviceMappingAttr.
  int64_t maxNumMappingsSupported =
      useLinearMapping ? (getMaxEnumValForMappingId() -
                          static_cast<uint64_t>(MappingId::DimZ))
                       : 3;
  if (forallOp.getRank() > maxNumMappingsSupported) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall with rank > ")
           << maxNumMappingsSupported
           << " does not lower for the specified mapping attribute type";
  }
  auto numParallelIterations =
      getConstantIntValues(forallOp.getMixedUpperBound());
  if (!forallOp.isNormalized() || !numParallelIterations.has_value()) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "requires statically sized, normalized forall op");
  }
  return DiagnosedSilenceableFailure::success();
}

/// Struct to return the result of the rewrite of a forall operation.
struct ForallRewriteResult {
  SmallVector<int64_t> mappingSizes;
  SmallVector<Value> mappingIds;
};

/// Helper to replace ids of dimensions known to be 1 by 0 to simplify the IR.
template <typename OpTy, typename OperationOrBlock>
static void
replaceUnitMappingIdsHelper(RewriterBase &rewriter, Location loc,
                            OperationOrBlock *parent, Value replacement,
                            ArrayRef<int64_t> availableMappingSizes) {
  parent->walk([&](OpTy idOp) {
    if (availableMappingSizes[static_cast<int64_t>(idOp.getDimension())] == 1)
      rewriter.replaceAllUsesWith(idOp.getResult(), replacement);
  });
}

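/// Shared implementation of the mapped scf.forall rewrite, used by both the
/// block mapping and the thread/warp/warpgroup mapping paths.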
static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    scf::ForallOp forallOp, ArrayRef<int64_t> availableMappingSizes,
    ForallRewriteResult &result, const GpuIdBuilder &gpuIdBuilder) {
  LDBG() << "--start rewriteOneForallCommonImpl";

  // Step 1. Complete the mapping to a full mapping (with 1s) if necessary.
  auto numParallelIterations =
      getConstantIntValues(forallOp.getMixedUpperBound());
  assert(forallOp.isNormalized() && numParallelIterations.has_value() &&
         "requires statically sized, normalized forall op");
  SmallVector<int64_t> tmpMappingSizes = numParallelIterations.value();
  SmallVector<DeviceMappingAttrInterface> forallMappingAttrsVec =
      forallOp.getDeviceMappingAttrs();
  SetVector<Attribute> forallMappingAttrs;
  forallMappingAttrs.insert_range(forallMappingAttrsVec);
  auto comparator = [](Attribute a, Attribute b) -> bool {
    return cast<DeviceMappingAttrInterface>(a).getMappingId() <
           cast<DeviceMappingAttrInterface>(b).getMappingId();
  };

  // Step 1.b. In the linear case, compute the max mapping to avoid needlessly
  // mapping all dimensions. In the 3-D mapping case we need to map all
  // dimensions.
  DeviceMappingAttrInterface maxMapping = cast<DeviceMappingAttrInterface>(
      *llvm::max_element(forallMappingAttrs, comparator));
  DeviceMappingAttrInterface maxLinearMapping;
  if (maxMapping.isLinearMapping())
    maxLinearMapping = maxMapping;
  for (auto attr : gpuIdBuilder.mappingAttributes) {
    // If attr overflows, just skip.
    if (maxLinearMapping && comparator(maxLinearMapping, attr))
      continue;
    // Try to insert. If element was already present, just continue.
    if (!forallMappingAttrs.insert(attr))
      continue;
    // Otherwise, we have a new insertion without a size -> use size 1.
    tmpMappingSizes.push_back(1);
  }
  LDBG() << "----tmpMappingSizes extracted from scf.forall op: "
         << llvm::interleaved(tmpMappingSizes);

  // Step 2. Sort the values by the corresponding DeviceMappingAttrInterface.
  SmallVector<int64_t> forallMappingSizes = getValuesSortedByKey(
      forallMappingAttrs.getArrayRef(), tmpMappingSizes, comparator);
  LDBG() << "----forallMappingSizes: " << llvm::interleaved(forallMappingSizes);
  LDBG() << "----forallMappingAttrs: " << llvm::interleaved(forallMappingAttrs);

  // Step 3. Generate the mappingIdOps using the provided generator.
  Location loc = forallOp.getLoc();
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(forallOp);
  SmallVector<int64_t> originalBasis(availableMappingSizes);
  bool originalBasisWasProvided = !originalBasis.empty();
  if (!originalBasisWasProvided) {
    LDBG() << "----originalBasis was not provided, deriving it and there will "
              "be no predication";
    originalBasis = forallMappingSizes;
    while (originalBasis.size() < 3)
      originalBasis.push_back(1);
  } else {
    LDBG() << "----originalBasis was provided, using it, there will be "
              "predication";
  }
  LDBG() << "------originalBasis: " << llvm::interleaved(originalBasis);

  IdBuilderResult builderResult =
      gpuIdBuilder.idBuilder(rewriter, loc, forallMappingSizes, originalBasis);
  if (!builderResult.errorMsg.empty())
    return definiteFailureHelper(transformOp, forallOp, builderResult.errorMsg);

  LDBG() << builderResult;

  // Step 4. Map the induction variables to the mappingIdOps, this may involve
  // a permutation.
  SmallVector<Value> mappingIdOps = builderResult.mappingIdOps;
  IRMapping bvm;
  for (auto [iv, dim] : llvm::zip_equal(
           forallOp.getInductionVars(),
           forallMappingAttrs.getArrayRef().take_front(forallOp.getRank()))) {
    auto mappingAttr = cast<DeviceMappingAttrInterface>(dim);
    Value peIdOp = mappingIdOps[mappingAttr.getRelativeIndex()];
    LDBG() << "----map: " << iv << " to " << peIdOp;
    bvm.map(iv, peIdOp);
  }

  // Step 5. If the originalBasis is already known, create conditionals to
  // predicate the region. Otherwise, the current forall determines the
  // originalBasis and no predication occurs.
  Value predicate;
  if (originalBasisWasProvided) {
    for (Value tmpPredicate : builderResult.predicateOps) {
      predicate = predicate ? arith::AndIOp::create(rewriter, loc, predicate,
                                                    tmpPredicate)
                            : tmpPredicate;
    }
  }

  // Step 6. Move the body of forallOp.
  // Erase the terminator first, it will not be used.
  rewriter.eraseOp(forallOp.getTerminator());
  Block *targetBlock;
  Block::iterator insertionPoint;
  if (predicate) {
    // Step 6.a. If predicated, move at the beginning.
    auto ifOp = scf::IfOp::create(rewriter, loc, predicate,
                                  /*withElseRegion=*/false);
    targetBlock = ifOp.thenBlock();
    insertionPoint = ifOp.thenBlock()->begin();
  } else {
    // Step 6.b. Otherwise, move inline just at the rewriter insertion
    // point.
    targetBlock = forallOp->getBlock();
    insertionPoint = rewriter.getInsertionPoint();
  }
  Block &sourceBlock = forallOp.getRegion().front();
  targetBlock->getOperations().splice(insertionPoint,
                                      sourceBlock.getOperations());

  // Step 7. RAUW indices.
  for (Value loopIndex : forallOp.getInductionVars()) {
    Value threadIdx = bvm.lookup(loopIndex);
    rewriter.replaceAllUsesWith(loopIndex, threadIdx);
  }

  // Step 8. Erase old op.
  rewriter.eraseOp(forallOp);

  LDBG() << "----result forallMappingSizes: "
         << llvm::interleaved(forallMappingSizes);
  LDBG() << "----result mappingIdOps: " << llvm::interleaved(mappingIdOps);

  result = ForallRewriteResult{forallMappingSizes, mappingIdOps};
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// MapForallToBlocks
//===----------------------------------------------------------------------===//

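/// Map the top level scf.forall op to GPU blocks.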
DiagnosedSilenceableFailure mlir::transform::gpu::mapForallToBlocksImpl(
    RewriterBase &rewriter, TransformOpInterface transformOp,
    scf::ForallOp forallOp, SmallVectorImpl<int64_t> &gridDims,
    const GpuIdBuilder &gpuIdBuilder) {
  LDBG() << "Start mapForallToBlocksImpl";

  {
    // GPU-specific verifications. There is no better place to anchor
    // those right now: the ForallOp is target-independent and the transform
    // op does not apply to individual ForallOp.
    DiagnosedSilenceableFailure diag =
        verifyGpuMapping<BlockMappingKind>(transformOp, forallOp);
    if (!diag.succeeded())
      return diag;
  }

  Location loc = forallOp.getLoc();
  Block *parentBlock = forallOp->getBlock();
  Value zero;
  {
    // Create an early zero index value for replacements and immediately reset
    // the insertion point.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(parentBlock);
    zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  }

  ForallRewriteResult rewriteResult;
  DiagnosedSilenceableFailure diag = rewriteOneForallCommonImpl(
      rewriter, transformOp, forallOp,
      /*availableMappingSizes=*/gridDims, rewriteResult, gpuIdBuilder);

  // Return if anything goes wrong, use silenceable failure as a match
  // failure.
  if (!diag.succeeded())
    return diag;

  // If gridDims was not provided already, set it from the return.
  if (gridDims.empty()) {
    gridDims = rewriteResult.mappingSizes;
    while (gridDims.size() < 3)
      gridDims.push_back(1);
  }
  assert(gridDims.size() == 3 && "Need 3-D gridDims");

  // Replace ids of dimensions known to be 1 by 0 to simplify the IR.
  // Here, the result of mapping determines the available mapping sizes.
  replaceUnitMappingIdsHelper<BlockDimOp>(rewriter, loc, parentBlock, zero,
                                          rewriteResult.mappingSizes);

  return DiagnosedSilenceableFailure::success();
}

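/// Find the unique top level scf::ForallOp within a given target op.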
static DiagnosedSilenceableFailure
findTopLevelForallOp(Operation *target, scf::ForallOp &topLevelForallOp,
                     TransformOpInterface transformOp) {
  auto walkResult = target->walk([&](scf::ForallOp forallOp) {
    if (forallOp->getParentOfType<scf::ForallOp>())
      return WalkResult::advance();
    if (topLevelForallOp)
      // TODO: Handle multiple forall if they are independent.
      return WalkResult::interrupt();
    topLevelForallOp = forallOp;
    return WalkResult::advance();
  });

  if (walkResult.wasInterrupted() || !topLevelForallOp)
    return transformOp.emitSilenceableError()
           << "could not find a unique topLevel scf.forall";
  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure transform::MapForallToBlocks::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, TransformState &state) {
  LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
  auto transformOp = cast<TransformOpInterface>(getOperation());

  if (!getGenerateGpuLaunch() && !gpuLaunch) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "Given target is not gpu.launch, set `generate_gpu_launch` "
           "attribute";
    diag.attachNote(target->getLoc()) << "when applied to this payload op";
    return diag;
  }

  scf::ForallOp topLevelForallOp;
  DiagnosedSilenceableFailure diag = findTopLevelForallOp(
      target, topLevelForallOp, transformOp);
  if (!diag.succeeded()) {
    diag.attachNote(target->getLoc()) << "when applied to this payload op";
    return diag;
  }
  assert(topLevelForallOp && "expect an scf.forall");

  SmallVector<int64_t> gridDims{getGridDims()};
  if (!getGenerateGpuLaunch() && gridDims.size() != 3)
    return transformOp.emitDefiniteFailure("transform requires size-3 mapping");

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(topLevelForallOp);

  // Generate the gpu launch here and move the forall inside.
  if (getGenerateGpuLaunch()) {
    DiagnosedSilenceableFailure diag =
        createGpuLaunch(rewriter, target->getLoc(), transformOp, gpuLaunch);
    if (!diag.succeeded())
      return diag;

    rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front());
    Operation *newForallOp = rewriter.clone(*topLevelForallOp);
    rewriter.eraseOp(topLevelForallOp);
    topLevelForallOp = cast<scf::ForallOp>(newForallOp);
  }

  // The BlockIdBuilder adapts to whatever is thrown at it.
  bool useLinearMapping = false;
  if (topLevelForallOp.getMapping())
    useLinearMapping = topLevelForallOp.usesLinearMapping();

  FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr =
      topLevelForallOp.getDeviceMaskingAttr();
  assert(succeeded(maybeMaskingAttr) && "unexpected failed maybeMaskingAttr");
  assert((!*maybeMaskingAttr || useLinearMapping) &&
         "masking requires linear mapping");

  GpuBlockIdBuilder gpuBlockIdBuilder(getContext(), useLinearMapping,
                                      *maybeMaskingAttr);

  diag = mlir::transform::gpu::mapForallToBlocksImpl(
      rewriter, transformOp, topLevelForallOp, gridDims, gpuBlockIdBuilder);
  if (!diag.succeeded())
    return diag;

  // Set the GPU launch configuration for the grid dims late, this is
  // subject to IR inspection.
  diag = alterGpuLaunch(rewriter, gpuLaunch,
                        cast<TransformOpInterface>(getOperation()), gridDims[0],
                        gridDims[1], gridDims[2]);

  results.push_back(gpuLaunch);
  return diag;
}

LogicalResult transform::MapForallToBlocks::verify() {
  if (!getGridDims().empty() && getGridDims().size() != 3) {
    return emitOpError() << "transform requires empty or size-3 grid_dims";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// MapNestedForallToThreads
//===----------------------------------------------------------------------===//

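/// Check that the mapping spec fits the available parallel resources: in 3-D
/// mapping mode the size of threadIdx.x must be a multiple of `factor`, and
/// the number of required resources must not exceed the available ones.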
773 
775  std::optional<TransformOpInterface> transformOp, scf::ForallOp forallOp,
776  ArrayRef<int64_t> numParallelIterations, ArrayRef<int64_t> blockOrGridSizes,
777  int factor, bool useLinearMapping = false) {
778  if (!useLinearMapping && blockOrGridSizes.front() % factor != 0) {
780  transformOp, forallOp,
781  Twine("3-D mapping: size of threadIdx.x must be a multiple of ") +
782  Twine(factor));
783  return diag;
784  }
785  if (computeProduct(numParallelIterations) * factor >
786  computeProduct(blockOrGridSizes)) {
788  transformOp, forallOp,
789  Twine("the number of required parallel resources (blocks or "
790  "threads) ") +
791  Twine(computeProduct(numParallelIterations) * factor) +
792  " overflows the number of available resources " +
793  Twine(computeProduct(blockOrGridSizes)));
794  return diag;
795  }
797 }
798 
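/// Build the GpuIdBuilder matching the forall op's mapping attribute
/// (threads, warps, warpgroups or lanes).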
static DiagnosedSilenceableFailure
getThreadIdBuilder(std::optional<TransformOpInterface> transformOp,
                   scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes,
                   int64_t warpSize, GpuIdBuilder &gpuIdBuilder) {
  DeviceMappingAttrInterface mappingAttr =
      forallOp.getDeviceMappingAttrs().front();
  bool useLinearMapping = mappingAttr.isLinearMapping();

  // Sanity checks that may result in runtime verification errors.
  auto numParallelIterations =
      getConstantIntValues(forallOp.getMixedUpperBound());
  if (!forallOp.isNormalized() || !numParallelIterations.has_value()) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "requires statically sized, normalized forall op");
  }
  int64_t factor = 1;
  if (isa<GPUWarpgroupMappingAttr>(mappingAttr)) {
    factor = GpuWarpgroupIdBuilder::kNumWarpsPerGroup * warpSize;
  } else if (isa<GPUWarpMappingAttr>(mappingAttr)) {
    factor = warpSize;
  }
  DiagnosedSilenceableFailure diag =
      checkMappingSpec(transformOp, forallOp, numParallelIterations.value(),
                       blockSizes, factor, useLinearMapping);
  if (!diag.succeeded())
    return diag;

  FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr =
      forallOp.getDeviceMaskingAttr();
  assert(succeeded(maybeMaskingAttr) && "unexpected failed maybeMaskingAttr");
  assert((!*maybeMaskingAttr || useLinearMapping) &&
         "masking requires linear mapping");

  // Start mapping.
  MLIRContext *ctx = forallOp.getContext();
  gpuIdBuilder =
      llvm::TypeSwitch<DeviceMappingAttrInterface, GpuIdBuilder>(mappingAttr)
          .Case([&](GPUWarpgroupMappingAttr) {
            return GpuWarpgroupIdBuilder(ctx, warpSize, useLinearMapping,
                                         *maybeMaskingAttr);
          })
          .Case([&](GPUWarpMappingAttr) {
            return GpuWarpIdBuilder(ctx, warpSize, useLinearMapping,
                                    *maybeMaskingAttr);
          })
          .Case([&](GPUThreadMappingAttr) {
            return GpuThreadIdBuilder(ctx, useLinearMapping, *maybeMaskingAttr);
          })
          .Case([&](GPULaneMappingAttr) {
            return GpuLaneIdBuilder(ctx, warpSize, useLinearMapping,
                                    *maybeMaskingAttr);
          })
          .Default([&](DeviceMappingAttrInterface) -> GpuIdBuilder {
            llvm_unreachable("unknown mapping attribute");
          });
  return DiagnosedSilenceableFailure::success();
}

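/// Map a single scf.forall op to GPU threads, warps, warpgroups or lanes
/// within the given block sizes.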
DiagnosedSilenceableFailure mlir::transform::gpu::mapOneForallToThreadsImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes, int64_t warpSize,
    bool syncAfterDistribute) {

  {
    // GPU-specific verifications. There is no better place to anchor
    // those right now: the ForallOp is target-independent and the transform
    // op does not apply to individual ForallOp.
    DiagnosedSilenceableFailure diag =
        verifyGpuMapping<ThreadMappingKind>(transformOp, forallOp);
    if (!diag.succeeded())
      return diag;
  }

  GpuIdBuilder gpuIdBuilder;
  {
    // Try to construct the id builder, if it fails, return.
    DiagnosedSilenceableFailure diag = getThreadIdBuilder(
        transformOp, forallOp, blockSizes, warpSize, gpuIdBuilder);
    if (!diag.succeeded())
      return diag;
  }

  Location loc = forallOp.getLoc();
  OpBuilder::InsertionGuard g(rewriter);
  // Insert after to allow for syncthreads after `forall` is erased.
  rewriter.setInsertionPointAfter(forallOp);
  ForallRewriteResult rewriteResult;
  DiagnosedSilenceableFailure diag = rewriteOneForallCommonImpl(
      rewriter, transformOp, forallOp, blockSizes, rewriteResult, gpuIdBuilder);
  if (!diag.succeeded())
    return diag;
  // Add a syncthreads if needed. TODO: warpsync
  if (syncAfterDistribute)
    BarrierOp::create(rewriter, loc);

  return DiagnosedSilenceableFailure::success();
}

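/// Search scf.forall ops nested under `target` and map each such op to an
/// explicit GPU implementation along `blockDims`.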
DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForallToThreadsImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    Operation *target, ArrayRef<int64_t> blockDims, int64_t warpSize,
    bool syncAfterDistribute) {
  LDBG() << "Start mapNestedForallToThreadsImpl";
  if (blockDims.size() != 3) {
    return definiteFailureHelper(transformOp, target,
                                 "requires size-3 thread mapping");
  }

  // Create an early zero index value for replacements.
  Location loc = target->getLoc();
  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
  WalkResult walkResult = target->walk([&](scf::ForallOp forallOp) {
    diag = mlir::transform::gpu::mapOneForallToThreadsImpl(
        rewriter, transformOp, forallOp, blockDims, warpSize,
        syncAfterDistribute);
    if (diag.isDefiniteFailure())
      return WalkResult::interrupt();
    if (diag.succeeded())
      return WalkResult::skip();
    return WalkResult::advance();
  });
  if (walkResult.wasInterrupted())
    return diag;

  // Replace ids of dimensions known to be 1 by 0 to simplify the IR.
  // Here, the result of mapping determines the available mapping sizes.
  replaceUnitMappingIdsHelper<ThreadIdOp>(rewriter, loc, target, zero,
                                          blockDims);

  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure transform::MapNestedForallToThreads::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, TransformState &state) {
  LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
  auto transformOp = cast<TransformOpInterface>(getOperation());

  // Basic high-level verifications.
  if (!gpuLaunch)
    return emitSilenceableError() << "Given target is not a gpu.launch";

  // Mapping to block ids.
  SmallVector<int64_t> blockDims{getBlockDims()};
  DiagnosedSilenceableFailure diag =
      checkGpuLimits(transformOp, std::nullopt, std::nullopt, std::nullopt,
                     blockDims[0], blockDims[1], blockDims[2]);
  if (diag.isSilenceableFailure()) {
    diag.attachNote(getLoc()) << getBlockDimsAttrName() << " is too large";
    return diag;
  }

  // Set the GPU launch configuration for the block dims early, this is not
  // subject to IR inspection.
  diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, std::nullopt,
                        std::nullopt, std::nullopt, blockDims[0], blockDims[1],
                        blockDims[2]);

  rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front());
  diag =
      mapNestedForallToThreadsImpl(rewriter, transformOp, gpuLaunch, blockDims,
                                   getWarpSize(), getSyncAfterDistribute());

  results.push_back(gpuLaunch.getOperation());
  return diag;
}

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//

namespace {
/// Registers the GPU transform ops and declares the dialects whose operations
/// they may generate (GPU, AMDGPU, Arith, SCF) as generated dialects.
class GPUTransformDialectExtension
    : public transform::TransformDialectExtension<
          GPUTransformDialectExtension> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GPUTransformDialectExtension)

  GPUTransformDialectExtension() {
    declareGeneratedDialect<GPUDialect>();
    declareGeneratedDialect<amdgpu::AMDGPUDialect>();
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<scf::SCFDialect>();
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"
        >();
  }
};
} // namespace

#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"

void mlir::gpu::registerTransformDialectExtension(DialectRegistry &registry) {
  registry.addExtensions<GPUTransformDialectExtension>();
}