62#define GEN_PASS_DEF_ACCCOMPUTELOWERING
63#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
67#define DEBUG_TYPE "acc-compute-lowering"
88static bool isEffectivelySerial(ParallelOp op) {
89 auto numGangs = op.getNumGangsValues();
90 if (numGangs.size() != 1)
92 Value numWorkers = op.getNumWorkersValue();
95 Value vectorLength = op.getVectorLengthValue();
103static bool isOpInComputeRegion(
Operation *op) {
108static bool isOpInSerialRegion(
Operation *op) {
110 return isEffectivelySerial(parallelOp);
112 return computeRegion.isEffectivelySerial();
117 auto attr = funcOp->getAttrOfType<SpecializedRoutineAttr>(
119 if (attr && attr.getLevel().getValue() == ParLevel::seq)
126static void setParDimsAttr(
Operation *op, GPUParallelDimsAttr attr) {
127 op->setAttr(GPUParallelDimsAttr::name, attr);
133static void materializeConstantLiveInsIntoRegion(
Region ®ion,
137 for (
Value v : liveInValues) {
143 "constants must have a single result");
144 constantLiveIns.push_back(v);
147 if (constantLiveIns.empty())
153 for (
Value v : constantLiveIns) {
156 liveInValues.remove(v);
163 GPUParallelDimAttr parDim) {
164 GPUParallelDimAttr *lb = llvm::lower_bound(
166 [](
const GPUParallelDimAttr &a,
const GPUParallelDimAttr &
b) {
167 return a.getOrder() >
b.getOrder();
169 if (lb == parDims.end() || *lb != parDim)
170 parDims.insert(lb, parDim);
177 DeviceType deviceType) {
179 auto *ctx = loopOp->getContext();
181 if (loopOp.hasVector(deviceType))
182 insertParDim(parDims, policy.
vectorDim(ctx));
183 if (loopOp.hasWorker(deviceType))
184 insertParDim(parDims, policy.
workerDim(ctx));
185 if (
auto gangDimValue = loopOp.getGangValue(GangArgType::Dim, deviceType)) {
186 if (
auto gangDimDefOp =
189 insertParDim(parDims, policy.
gangDim(ctx, gangLevel));
191 }
else if (loopOp.hasGang(deviceType)) {
192 insertParDim(parDims, policy.
gangDim(ctx, ParLevel::gang_dim1));
200template <
typename ComputeConstructT>
202assignKnownLaunchArgs(ComputeConstructT computeOp, DeviceType deviceType,
208 auto loc = computeOp->getLoc();
210 auto numGangs = computeOp.getNumGangsValues(deviceType);
211 if (numGangs.empty())
212 numGangs = computeOp.getNumGangsValues();
213 for (
auto [gangDimIdx, gangSize] : llvm::enumerate(numGangs)) {
216 ParWidthOp::create(rewriter, loc,
218 rewriter, gangSize.getLoc(), indexTy, gangSize),
219 policy.
gangDim(ctx, gangLevel)));
222 Value numWorkers = computeOp.getNumWorkersValue(deviceType);
224 numWorkers = computeOp.getNumWorkersValue();
226 values.push_back(ParWidthOp::create(
233 Value vectorLength = computeOp.getVectorLengthValue(deviceType);
235 vectorLength = computeOp.getVectorLengthValue();
237 values.push_back(ParWidthOp::create(
240 indexTy, vectorLength),
249assignKnownLaunchArgs<SerialOp>(SerialOp, DeviceType,
RewriterBase &,
261 DeviceType deviceType)
264 LogicalResult matchAndRewrite(LoopOp loopOp,
266 if (loopOp.getUnstructured()) {
271 rewriter.
replaceOp(loopOp, executeRegion);
275 LoopParMode parMode = loopOp.getDefaultOrDeviceTypeParallelism(deviceType);
277 if (parMode == LoopParMode::loop_seq || isOpInSerialRegion(loopOp)) {
284 setParDimsAttr(parallelOp,
285 GPUParallelDimsAttr::seq(loopOp->getContext()));
287 }
else if (parMode == LoopParMode::loop_auto) {
289 assert(!isOpInSerialRegion(loopOp) &&
290 "Expected loop to be in non-serial region");
297 }
else if (!isOpInComputeRegion(loopOp) &&
299 loopOp->getParentOfType<FunctionOpInterface>())) {
308 assert(parMode == LoopParMode::loop_independent &&
309 "Expected loop to be independent");
315 getParallelDimensions(loopOp, policy, deviceType);
316 if (!parDims.empty()) {
318 GPUParallelDimsAttr::get(loopOp->getContext(), parDims);
319 setParDimsAttr(parallelOp, parDimsAttr);
329 DeviceType deviceType;
336template <
typename ComputeConstructT>
340 DeviceType deviceType)
342 deviceType(deviceType) {}
344 LogicalResult matchAndRewrite(ComputeConstructT computeOp,
348 KernelEnvironmentOp::createAndPopulate(computeOp, rewriter);
350 assignKnownLaunchArgs(computeOp, deviceType, rewriter, policy);
351 Region ®ion = computeOp.getRegion();
354 materializeConstantLiveInsIntoRegion(region, liveInValues, rewriter);
357 computeOp->getLoc(), launchArgs, liveInValues.getArrayRef(),
358 ComputeConstructT::getOperationName(), region, rewriter, mapping);
359 if (!computeRegion) {
369 DeviceType deviceType;
376class ACCComputeLowering
377 :
public acc::impl::ACCComputeLoweringBase<ACCComputeLowering> {
379 using ACCComputeLoweringBase::ACCComputeLoweringBase;
381 void runOnOperation()
override {
382 auto op = getOperation();
383 auto *context = op.getContext();
391 loopPatterns.
insert<ACCLoopConversion>(context, policy, deviceType);
393 return signalPassFailure();
399 .
insert<ComputeOpConversion<ParallelOp>, ComputeOpConversion<KernelsOp>,
400 ComputeOpConversion<SerialOp>>(context, policy, deviceType);
402 return signalPassFailure();
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
MLIRContext * getContext() const
This is a utility class for mapping one set of IR entities to another.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation is the basic unit of execution within MLIR.
Block * getBlock()
Returns the operation block that contains this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
ParDimAttrT vectorDim(MLIRContext *ctx) const
ParDimAttrT workerDim(MLIRContext *ctx) const
ParDimAttrT gangDim(MLIRContext *ctx, ParLevel level) const
Convenience methods for specific parallelism levels.
Default policy that provides the standard GPU mapping: gang(dim:1) -> BlockX (gridDim....
Specialization of arith.constant op that returns an integer value.
ParLevel getGangParLevel(int64_t gangDimValue)
Convert a gang dimension value (1, 2, or 3) to the corresponding ParLevel.
static constexpr StringLiteral getSpecializedRoutineAttrName()
ComputeRegionOp buildComputeRegion(Location loc, ValueRange launchArgs, ValueRange inputArgs, llvm::StringRef origin, Region &regionToClone, RewriterBase &rewriter, IRMapping &mapping, ValueRange output={}, FlatSymbolRefAttr kernelFuncName={}, FlatSymbolRefAttr kernelModuleName={}, Value stream={}, ValueRange inputArgsToMap={})
Build an acc.compute_region operation by cloning a source region.
bool isSpecializedAccRoutine(mlir::Operation *op)
Used to check whether this is a specialized accelerator version of acc routine function.
scf::ParallelOp convertACCLoopToSCFParallel(LoopOp loopOp, RewriterBase &rewriter)
Convert acc.loop to scf.parallel.
mlir::Operation * getEnclosingComputeOp(mlir::Region &region)
Used to obtain the enclosing compute construct operation that contains the provided region.
scf::ExecuteRegionOp convertUnstructuredACCLoopToSCFExecuteRegion(LoopOp loopOp, RewriterBase &rewriter)
Convert an unstructured acc.loop to scf.execute_region.
scf::ForOp convertACCLoopToSCFFor(LoopOp loopOp, RewriterBase &rewriter, bool enableCollapse)
Convert a structured acc.loop to scf.for.
ACCParMappingPolicy< mlir::acc::GPUParallelDimAttr > ACCToGPUMappingPolicy
Type alias for the GPU-specific mapping policy.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
void replaceAllUsesInRegionWith(Value orig, Value replacement, Region &region)
Replace all uses of orig within the given region with replacement.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
llvm::SetVector< T, Vector, Set, N > SetVector
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...