61#include "llvm/ADT/STLExtras.h"
65#define GEN_PASS_DEF_ACCCOMPUTELOWERING
66#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
70#define DEBUG_TYPE "acc-compute-lowering"
88template <
typename ComputeOpT>
89static bool isGangWorkerVectorAllOne(ComputeOpT op) {
90 auto numGangs = op.getNumGangsValues();
93 for (
Value gangSize : numGangs) {
97 Value numWorkers = op.getNumWorkersValue();
100 Value vectorLength = op.getVectorLengthValue();
111template <
typename ComputeOpT>
112static bool isEffectivelySerial(ComputeOpT op) {
113 return isGangWorkerVectorAllOne(op);
116static bool isOpInComputeRegion(
Operation *op) {
121static bool isOpInSerialRegion(
Operation *op) {
123 return isEffectivelySerial(parallelOp);
125 return isEffectivelySerial(kernelsOp);
129 return computeRegion.isEffectivelySerial();
132 auto attr = funcOp->getAttrOfType<SpecializedRoutineAttr>(
134 if (attr && attr.getLevel().getValue() == ParLevel::seq)
141static void setParDimsAttr(
Operation *op, GPUParallelDimsAttr attr) {
142 op->setAttr(GPUParallelDimsAttr::name, attr);
148static void materializeConstantLiveInsIntoRegion(
Region ®ion,
152 for (
Value v : liveInValues) {
158 "constants must have a single result");
159 constantLiveIns.push_back(v);
162 if (constantLiveIns.empty())
168 for (
Value v : constantLiveIns) {
171 liveInValues.remove(v);
178 GPUParallelDimAttr parDim) {
179 GPUParallelDimAttr *lb = llvm::lower_bound(
181 [](
const GPUParallelDimAttr &a,
const GPUParallelDimAttr &
b) {
182 return a.getOrder() >
b.getOrder();
184 if (lb == parDims.end() || *lb != parDim)
185 parDims.insert(lb, parDim);
191static DeviceType getGangWorkerVectorDeviceType(LoopOp loopOp,
192 DeviceType deviceType) {
193 if (deviceType != DeviceType::None &&
194 loopOp.hasAnyGangWorkerVector(deviceType))
196 return DeviceType::None;
199template <
typename ComputeConstructT>
200static DeviceType getParDimsDeviceType(ComputeConstructT computeOp,
201 DeviceType deviceType) {
202 if (deviceType != DeviceType::None &&
203 computeOp.hasAnyGangWorkerVector(deviceType))
205 return DeviceType::None;
212 DeviceType deviceType) {
213 deviceType = getGangWorkerVectorDeviceType(loopOp, deviceType);
215 auto *ctx = loopOp->getContext();
217 if (loopOp.hasVector(deviceType))
218 insertParDim(parDims, policy.
vectorDim(ctx));
219 if (loopOp.hasWorker(deviceType))
220 insertParDim(parDims, policy.
workerDim(ctx));
221 if (
auto gangDimValue = loopOp.getGangValue(GangArgType::Dim, deviceType)) {
222 if (
auto gangDimDefOp =
225 insertParDim(parDims, policy.
gangDim(ctx, gangLevel));
227 }
else if (loopOp.hasGang(deviceType)) {
228 insertParDim(parDims, policy.
gangDim(ctx, ParLevel::gang_dim1));
238template <
typename ComputeConstructT>
240assignKnownLaunchArgs(ComputeConstructT computeOp, DeviceType deviceType,
244 auto loc = computeOp->getLoc();
246 if constexpr (std::is_same_v<ComputeConstructT, SerialOp>) {
247 return {ParWidthOp::create(rewriter, loc,
Value(), policy.
seqDim(ctx))};
248 }
else if constexpr (llvm::is_one_of<ComputeConstructT, ParallelOp,
250 if (isEffectivelySerial(computeOp))
251 return {ParWidthOp::create(rewriter, loc,
Value(), policy.
seqDim(ctx))};
253 deviceType = getParDimsDeviceType(computeOp, deviceType);
258 auto numGangs = computeOp.getNumGangsValues(deviceType);
259 for (
auto [gangDimIdx, gangSize] : llvm::enumerate(numGangs)) {
261 values.push_back(ParWidthOp::create(
265 policy.
gangDim(ctx, gangLevel)));
268 Value numWorkers = computeOp.getNumWorkersValue(deviceType);
270 values.push_back(ParWidthOp::create(
273 indexTy, numWorkers),
277 Value vectorLength = computeOp.getVectorLengthValue(deviceType);
279 values.push_back(ParWidthOp::create(
282 indexTy, vectorLength),
287 llvm_unreachable(
"assignKnownLaunchArgs: expected parallel, kernels, or "
299 DeviceType deviceType)
302 LogicalResult matchAndRewrite(LoopOp loopOp,
304 if (loopOp.getUnstructured()) {
309 rewriter.
replaceOp(loopOp, executeRegion);
313 LoopParMode parMode = loopOp.getDefaultOrDeviceTypeParallelism(deviceType);
315 if (parMode == LoopParMode::loop_seq || isOpInSerialRegion(loopOp)) {
322 setParDimsAttr(parallelOp,
323 GPUParallelDimsAttr::seq(loopOp->getContext()));
325 }
else if (parMode == LoopParMode::loop_auto) {
327 assert(!isOpInSerialRegion(loopOp) &&
328 "Expected loop to be in non-serial region");
335 getParallelDimensions(loopOp, policy, deviceType);
336 if (!parDims.empty()) {
338 GPUParallelDimsAttr::get(loopOp->getContext(), parDims);
339 setParDimsAttr(forOp, parDimsAttr);
342 }
else if (!isOpInComputeRegion(loopOp) &&
344 loopOp->getParentOfType<FunctionOpInterface>())) {
353 assert(parMode == LoopParMode::loop_independent &&
354 "Expected loop to be independent");
360 getParallelDimensions(loopOp, policy, deviceType);
361 if (!parDims.empty()) {
363 GPUParallelDimsAttr::get(loopOp->getContext(), parDims);
364 setParDimsAttr(parallelOp, parDimsAttr);
374 DeviceType deviceType;
381template <
typename ComputeConstructT>
385 DeviceType deviceType)
387 deviceType(deviceType) {}
389 LogicalResult matchAndRewrite(ComputeConstructT computeOp,
393 KernelEnvironmentOp::createAndPopulate(computeOp, deviceType, rewriter);
395 assignKnownLaunchArgs(computeOp, deviceType, rewriter, policy);
396 Region ®ion = computeOp.getRegion();
399 materializeConstantLiveInsIntoRegion(region, liveInValues, rewriter);
402 computeOp->getLoc(), launchArgs, liveInValues.getArrayRef(),
403 ComputeConstructT::getOperationName(), region, rewriter, mapping);
404 if (!computeRegion) {
414 DeviceType deviceType;
421class ACCComputeLowering
422 :
public acc::impl::ACCComputeLoweringBase<ACCComputeLowering> {
424 using ACCComputeLoweringBase::ACCComputeLoweringBase;
426 void runOnOperation()
override {
427 auto op = getOperation();
428 auto *context = op.getContext();
436 loopPatterns.
insert<ACCLoopConversion>(context, policy, deviceType);
438 return signalPassFailure();
444 .
insert<ComputeOpConversion<ParallelOp>, ComputeOpConversion<KernelsOp>,
445 ComputeOpConversion<SerialOp>>(context, policy, deviceType);
447 return signalPassFailure();
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
MLIRContext * getContext() const
This is a utility class for mapping one set of IR entities to another.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation is the basic unit of execution within MLIR.
Block * getBlock()
Returns the operation block that contains this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
ParDimAttrT seqDim(MLIRContext *ctx) const
ParDimAttrT vectorDim(MLIRContext *ctx) const
ParDimAttrT workerDim(MLIRContext *ctx) const
ParDimAttrT gangDim(MLIRContext *ctx, ParLevel level) const
Convenience methods for specific parallelism levels.
Default policy that provides the standard GPU mapping: gang(dim:1) -> BlockX (gridDim....
Specialization of arith.constant op that returns an integer value.
ParLevel getGangParLevel(int64_t gangDimValue)
Convert a gang dimension value (1, 2, or 3) to the corresponding ParLevel.
static constexpr StringLiteral getSpecializedRoutineAttrName()
ComputeRegionOp buildComputeRegion(Location loc, ValueRange launchArgs, ValueRange inputArgs, llvm::StringRef origin, Region &regionToClone, RewriterBase &rewriter, IRMapping &mapping, ValueRange output={}, FlatSymbolRefAttr kernelFuncName={}, FlatSymbolRefAttr kernelModuleName={}, Value stream={}, ValueRange inputArgsToMap={})
Build an acc.compute_region operation by cloning a source region.
bool isSpecializedAccRoutine(mlir::Operation *op)
Used to check whether this is a specialized accelerator version of acc routine function.
scf::ParallelOp convertACCLoopToSCFParallel(LoopOp loopOp, RewriterBase &rewriter)
Convert acc.loop to scf.parallel.
mlir::Operation * getEnclosingComputeOp(mlir::Region &region)
Used to obtain the enclosing compute construct operation that contains the provided region.
scf::ExecuteRegionOp convertUnstructuredACCLoopToSCFExecuteRegion(LoopOp loopOp, RewriterBase &rewriter)
Convert an unstructured acc.loop to scf.execute_region.
scf::ForOp convertACCLoopToSCFFor(LoopOp loopOp, RewriterBase &rewriter, bool enableCollapse)
Convert a structured acc.loop to scf.for.
ACCParMappingPolicy< mlir::acc::GPUParallelDimAttr > ACCToGPUMappingPolicy
Type alias for the GPU-specific mapping policy.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
void replaceAllUsesInRegionWith(Value orig, Value replacement, Region &region)
Replace all uses of orig within the given region with replacement.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
llvm::SetVector< T, Vector, Set, N > SetVector
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...