59#define GEN_PASS_DEF_ACCCOMPUTELOWERING
60#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
64#define DEBUG_TYPE "acc-compute-lowering"
85static bool isEffectivelySerial(ParallelOp op) {
86 auto numGangs = op.getNumGangsValues();
87 if (numGangs.size() != 1)
89 Value numWorkers = op.getNumWorkersValue();
92 Value vectorLength = op.getVectorLengthValue();
100static bool isOpInComputeRegion(
Operation *op) {
105static bool isOpInSerialRegion(
Operation *op) {
107 return isEffectivelySerial(parallelOp);
109 return computeRegion.isEffectivelySerial();
114 auto attr = funcOp->getAttrOfType<SpecializedRoutineAttr>(
116 if (attr && attr.getLevel().getValue() == ParLevel::seq)
123static void setParDimsAttr(
Operation *op, GPUParallelDimsAttr attr) {
124 op->setAttr(GPUParallelDimsAttr::name, attr);
130 GPUParallelDimAttr parDim) {
131 GPUParallelDimAttr *lb = llvm::lower_bound(
133 [](
const GPUParallelDimAttr &a,
const GPUParallelDimAttr &
b) {
134 return a.getOrder() >
b.getOrder();
136 if (lb == parDims.end() || *lb != parDim)
137 parDims.insert(lb, parDim);
144 DeviceType deviceType) {
146 auto *ctx = loopOp->getContext();
148 if (loopOp.hasVector(deviceType))
149 insertParDim(parDims, policy.
vectorDim(ctx));
150 if (loopOp.hasWorker(deviceType))
151 insertParDim(parDims, policy.
workerDim(ctx));
152 if (
auto gangDimValue = loopOp.getGangValue(GangArgType::Dim, deviceType)) {
153 if (
auto gangDimDefOp =
156 insertParDim(parDims, policy.
gangDim(ctx, gangLevel));
158 }
else if (loopOp.hasGang(deviceType)) {
159 insertParDim(parDims, policy.
gangDim(ctx, ParLevel::gang_dim1));
167template <
typename ComputeConstructT>
169assignKnownLaunchArgs(ComputeConstructT computeOp, DeviceType deviceType,
175 auto loc = computeOp->getLoc();
177 auto numGangs = computeOp.getNumGangsValues(deviceType);
178 if (numGangs.empty())
179 numGangs = computeOp.getNumGangsValues();
180 for (
auto [gangDimIdx, gangSize] : llvm::enumerate(numGangs)) {
183 ParWidthOp::create(rewriter, loc,
185 rewriter, gangSize.getLoc(), indexTy, gangSize),
186 policy.
gangDim(ctx, gangLevel)));
189 Value numWorkers = computeOp.getNumWorkersValue(deviceType);
191 numWorkers = computeOp.getNumWorkersValue();
193 values.push_back(ParWidthOp::create(
200 Value vectorLength = computeOp.getVectorLengthValue(deviceType);
202 vectorLength = computeOp.getVectorLengthValue();
204 values.push_back(ParWidthOp::create(
207 indexTy, vectorLength),
216assignKnownLaunchArgs<SerialOp>(SerialOp, DeviceType,
RewriterBase &,
228 DeviceType deviceType)
231 LogicalResult matchAndRewrite(LoopOp loopOp,
233 if (loopOp.getUnstructured()) {
238 rewriter.
replaceOp(loopOp, executeRegion);
242 LoopParMode parMode = loopOp.getDefaultOrDeviceTypeParallelism(deviceType);
244 if (parMode == LoopParMode::loop_seq || isOpInSerialRegion(loopOp)) {
251 setParDimsAttr(parallelOp,
252 GPUParallelDimsAttr::seq(loopOp->getContext()));
254 }
else if (parMode == LoopParMode::loop_auto) {
256 assert(!isOpInSerialRegion(loopOp) &&
257 "Expected loop to be in non-serial region");
264 }
else if (!isOpInComputeRegion(loopOp) &&
266 loopOp->getParentOfType<FunctionOpInterface>())) {
275 assert(parMode == LoopParMode::loop_independent &&
276 "Expected loop to be independent");
282 getParallelDimensions(loopOp, policy, deviceType);
283 if (!parDims.empty()) {
285 GPUParallelDimsAttr::get(loopOp->getContext(), parDims);
286 setParDimsAttr(parallelOp, parDimsAttr);
296 DeviceType deviceType;
303template <
typename ComputeConstructT>
307 DeviceType deviceType)
309 deviceType(deviceType) {}
311 LogicalResult matchAndRewrite(ComputeConstructT computeOp,
315 KernelEnvironmentOp::createAndPopulate(computeOp, rewriter);
317 assignKnownLaunchArgs(computeOp, deviceType, rewriter, policy);
318 Region ®ion = computeOp.getRegion();
323 computeOp->getLoc(), launchArgs, liveInValues.getArrayRef(),
324 ComputeConstructT::getOperationName(), region, rewriter, mapping);
325 if (!computeRegion) {
335 DeviceType deviceType;
342class ACCComputeLowering
343 :
public acc::impl::ACCComputeLoweringBase<ACCComputeLowering> {
345 using ACCComputeLoweringBase::ACCComputeLoweringBase;
347 void runOnOperation()
override {
348 auto op = getOperation();
349 auto *context = op.getContext();
357 loopPatterns.
insert<ACCLoopConversion>(context, policy, deviceType);
359 return signalPassFailure();
365 .
insert<ComputeOpConversion<ParallelOp>, ComputeOpConversion<KernelsOp>,
366 ComputeOpConversion<SerialOp>>(context, policy, deviceType);
368 return signalPassFailure();
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
MLIRContext * getContext() const
This is a utility class for mapping one set of IR entities to another.
MLIRContext is the top-level object for a collection of MLIR operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation is the basic unit of execution within MLIR.
Block * getBlock()
Returns the operation block that contains this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
ParDimAttrT vectorDim(MLIRContext *ctx) const
ParDimAttrT workerDim(MLIRContext *ctx) const
ParDimAttrT gangDim(MLIRContext *ctx, ParLevel level) const
Convenience methods for specific parallelism levels.
Default policy that provides the standard GPU mapping: gang(dim:1) -> BlockX (gridDim....
Specialization of arith.constant op that returns an integer value.
ParLevel getGangParLevel(int64_t gangDimValue)
Convert a gang dimension value (1, 2, or 3) to the corresponding ParLevel.
static constexpr StringLiteral getSpecializedRoutineAttrName()
bool isSpecializedAccRoutine(mlir::Operation *op)
Used to check whether this is a specialized accelerator version of acc routine function.
scf::ParallelOp convertACCLoopToSCFParallel(LoopOp loopOp, RewriterBase &rewriter)
Convert acc.loop to scf.parallel.
mlir::Operation * getEnclosingComputeOp(mlir::Region ®ion)
Used to obtain the enclosing compute construct operation that contains the provided region.
scf::ExecuteRegionOp convertUnstructuredACCLoopToSCFExecuteRegion(LoopOp loopOp, RewriterBase &rewriter)
Convert an unstructured acc.loop to scf.execute_region.
scf::ForOp convertACCLoopToSCFFor(LoopOp loopOp, RewriterBase &rewriter, bool enableCollapse)
Convert a structured acc.loop to scf.for.
ACCParMappingPolicy< mlir::acc::GPUParallelDimAttr > ACCToGPUMappingPolicy
Type alias for the GPU-specific mapping policy.
ComputeRegionOp buildComputeRegion(Location loc, ValueRange launchArgs, ValueRange inputArgs, llvm::StringRef origin, Region ®ionToClone, RewriterBase &rewriter, IRMapping &mapping, ValueRange output={}, FlatSymbolRefAttr kernelFuncName={}, FlatSymbolRefAttr kernelModuleName={}, Value stream={})
Build an acc.compute_region operation by cloning a source region.
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
llvm::SetVector< T, Vector, Set, N > SetVector
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
void getUsedValuesDefinedAbove(Region ®ion, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...