61#include "llvm/ADT/STLExtras.h"
65#define GEN_PASS_DEF_ACCCOMPUTELOWERING
66#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
70#define DEBUG_TYPE "acc-compute-lowering"
88template <
typename ComputeOpT>
89static bool isGangWorkerVectorAllOne(ComputeOpT op) {
90 auto numGangs = op.getNumGangsValues();
93 for (
Value gangSize : numGangs) {
97 Value numWorkers = op.getNumWorkersValue();
100 Value vectorLength = op.getVectorLengthValue();
111template <
typename ComputeOpT>
112static bool isEffectivelySerial(ComputeOpT op) {
113 return isGangWorkerVectorAllOne(op);
116static bool isOpInComputeRegion(
Operation *op) {
121static bool isOpInSerialRegion(
Operation *op) {
123 return isEffectivelySerial(parallelOp);
125 return isEffectivelySerial(kernelsOp);
129 return computeRegion.isEffectivelySerial();
132 auto attr = funcOp->getAttrOfType<SpecializedRoutineAttr>(
134 if (attr && attr.getLevel().getValue() == ParLevel::seq)
141static void setParDimsAttr(
Operation *op, GPUParallelDimsAttr attr) {
142 op->setAttr(GPUParallelDimsAttr::name, attr);
148static void materializeConstantLiveInsIntoRegion(
Region ®ion,
152 for (
Value v : liveInValues) {
158 "constants must have a single result");
159 constantLiveIns.push_back(v);
162 if (constantLiveIns.empty())
168 for (
Value v : constantLiveIns) {
171 liveInValues.remove(v);
178 GPUParallelDimAttr parDim) {
179 GPUParallelDimAttr *lb = llvm::lower_bound(
181 [](
const GPUParallelDimAttr &a,
const GPUParallelDimAttr &
b) {
182 return a.getOrder() >
b.getOrder();
184 if (lb == parDims.end() || *lb != parDim)
185 parDims.insert(lb, parDim);
192 DeviceType deviceType) {
194 auto *ctx = loopOp->getContext();
196 if (loopOp.hasVector(deviceType))
197 insertParDim(parDims, policy.
vectorDim(ctx));
198 if (loopOp.hasWorker(deviceType))
199 insertParDim(parDims, policy.
workerDim(ctx));
200 if (
auto gangDimValue = loopOp.getGangValue(GangArgType::Dim, deviceType)) {
201 if (
auto gangDimDefOp =
204 insertParDim(parDims, policy.
gangDim(ctx, gangLevel));
206 }
else if (loopOp.hasGang(deviceType)) {
207 insertParDim(parDims, policy.
gangDim(ctx, ParLevel::gang_dim1));
217template <
typename ComputeConstructT>
219assignKnownLaunchArgs(ComputeConstructT computeOp, DeviceType deviceType,
223 auto loc = computeOp->getLoc();
225 if constexpr (std::is_same_v<ComputeConstructT, SerialOp>) {
226 return {ParWidthOp::create(rewriter, loc,
Value(), policy.
seqDim(ctx))};
227 }
else if constexpr (llvm::is_one_of<ComputeConstructT, ParallelOp,
229 if (isEffectivelySerial(computeOp))
230 return {ParWidthOp::create(rewriter, loc,
Value(), policy.
seqDim(ctx))};
235 auto numGangs = computeOp.getNumGangsValues(deviceType);
236 if (numGangs.empty())
237 numGangs = computeOp.getNumGangsValues();
238 for (
auto [gangDimIdx, gangSize] : llvm::enumerate(numGangs)) {
240 values.push_back(ParWidthOp::create(
244 policy.
gangDim(ctx, gangLevel)));
247 Value numWorkers = computeOp.getNumWorkersValue(deviceType);
249 numWorkers = computeOp.getNumWorkersValue();
251 values.push_back(ParWidthOp::create(
254 indexTy, numWorkers),
258 Value vectorLength = computeOp.getVectorLengthValue(deviceType);
260 vectorLength = computeOp.getVectorLengthValue();
262 values.push_back(ParWidthOp::create(
265 indexTy, vectorLength),
270 llvm_unreachable(
"assignKnownLaunchArgs: expected parallel, kernels, or "
282 DeviceType deviceType)
285 LogicalResult matchAndRewrite(LoopOp loopOp,
287 if (loopOp.getUnstructured()) {
292 rewriter.
replaceOp(loopOp, executeRegion);
296 LoopParMode parMode = loopOp.getDefaultOrDeviceTypeParallelism(deviceType);
298 if (parMode == LoopParMode::loop_seq || isOpInSerialRegion(loopOp)) {
305 setParDimsAttr(parallelOp,
306 GPUParallelDimsAttr::seq(loopOp->getContext()));
308 }
else if (parMode == LoopParMode::loop_auto) {
310 assert(!isOpInSerialRegion(loopOp) &&
311 "Expected loop to be in non-serial region");
318 }
else if (!isOpInComputeRegion(loopOp) &&
320 loopOp->getParentOfType<FunctionOpInterface>())) {
329 assert(parMode == LoopParMode::loop_independent &&
330 "Expected loop to be independent");
336 getParallelDimensions(loopOp, policy, deviceType);
337 if (!parDims.empty()) {
339 GPUParallelDimsAttr::get(loopOp->getContext(), parDims);
340 setParDimsAttr(parallelOp, parDimsAttr);
350 DeviceType deviceType;
357template <
typename ComputeConstructT>
361 DeviceType deviceType)
363 deviceType(deviceType) {}
365 LogicalResult matchAndRewrite(ComputeConstructT computeOp,
369 KernelEnvironmentOp::createAndPopulate(computeOp, rewriter);
371 assignKnownLaunchArgs(computeOp, deviceType, rewriter, policy);
372 Region ®ion = computeOp.getRegion();
375 materializeConstantLiveInsIntoRegion(region, liveInValues, rewriter);
378 computeOp->getLoc(), launchArgs, liveInValues.getArrayRef(),
379 ComputeConstructT::getOperationName(), region, rewriter, mapping);
380 if (!computeRegion) {
390 DeviceType deviceType;
397class ACCComputeLowering
398 :
public acc::impl::ACCComputeLoweringBase<ACCComputeLowering> {
400 using ACCComputeLoweringBase::ACCComputeLoweringBase;
402 void runOnOperation()
override {
403 auto op = getOperation();
404 auto *context = op.getContext();
412 loopPatterns.
insert<ACCLoopConversion>(context, policy, deviceType);
414 return signalPassFailure();
420 .
insert<ComputeOpConversion<ParallelOp>, ComputeOpConversion<KernelsOp>,
421 ComputeOpConversion<SerialOp>>(context, policy, deviceType);
423 return signalPassFailure();
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
MLIRContext * getContext() const
This is a utility class for mapping one set of IR entities to another.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation using the map that is provided (leaving them alone if no entry is present).
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation is the basic unit of execution within MLIR.
Block * getBlock()
Returns the operation block that contains this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched and rewritten.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
ParDimAttrT seqDim(MLIRContext *ctx) const
ParDimAttrT vectorDim(MLIRContext *ctx) const
ParDimAttrT workerDim(MLIRContext *ctx) const
ParDimAttrT gangDim(MLIRContext *ctx, ParLevel level) const
Convenience methods for specific parallelism levels.
Default policy that provides the standard GPU mapping: gang(dim:1) -> BlockX (gridDim.x).
Specialization of arith.constant op that returns an integer value.
ParLevel getGangParLevel(int64_t gangDimValue)
Convert a gang dimension value (1, 2, or 3) to the corresponding ParLevel.
static constexpr StringLiteral getSpecializedRoutineAttrName()
ComputeRegionOp buildComputeRegion(Location loc, ValueRange launchArgs, ValueRange inputArgs, llvm::StringRef origin, Region ®ionToClone, RewriterBase &rewriter, IRMapping &mapping, ValueRange output={}, FlatSymbolRefAttr kernelFuncName={}, FlatSymbolRefAttr kernelModuleName={}, Value stream={}, ValueRange inputArgsToMap={})
Build an acc.compute_region operation by cloning a source region.
bool isSpecializedAccRoutine(mlir::Operation *op)
Used to check whether this is a specialized accelerator version of an acc routine function.
scf::ParallelOp convertACCLoopToSCFParallel(LoopOp loopOp, RewriterBase &rewriter)
Convert acc.loop to scf.parallel.
mlir::Operation * getEnclosingComputeOp(mlir::Region ®ion)
Used to obtain the enclosing compute construct operation that contains the provided region.
scf::ExecuteRegionOp convertUnstructuredACCLoopToSCFExecuteRegion(LoopOp loopOp, RewriterBase &rewriter)
Convert an unstructured acc.loop to scf.execute_region.
scf::ForOp convertACCLoopToSCFFor(LoopOp loopOp, RewriterBase &rewriter, bool enableCollapse)
Convert a structured acc.loop to scf.for.
ACCParMappingPolicy< mlir::acc::GPUParallelDimAttr > ACCToGPUMappingPolicy
Type alias for the GPU-specific mapping policy.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
void replaceAllUsesInRegionWith(Value orig, Value replacement, Region ®ion)
Replace all uses of orig within the given region with replacement.
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist driven manner until a fixpoint is reached.
llvm::SetVector< T, Vector, Set, N > SetVector
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
void getUsedValuesDefinedAbove(Region ®ion, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region or its descendants.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.