MLIR 22.0.0git
ACCLegalizeSerial.cpp
Go to the documentation of this file.
//===- ACCLegalizeSerial.cpp - Legalize ACC Serial region -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass converts acc.serial into acc.parallel with num_gangs(1)
// num_workers(1) vector_length(1).
//
// This transformation simplifies processing of acc regions by unifying the
// handling of serial and parallel constructs. Since an OpenACC serial region
// executes sequentially (like a parallel region with a single gang, worker, and
// vector), this conversion is semantically equivalent while enabling code reuse
// in later compilation stages.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/OpenACC/Transforms/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
36
37namespace mlir {
38namespace acc {
39#define GEN_PASS_DEF_ACCLEGALIZESERIAL
40#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
41} // namespace acc
42} // namespace mlir
43
44#define DEBUG_TYPE "acc-legalize-serial"
45
46namespace {
47using namespace mlir;
48
49struct ACCSerialOpConversion : public OpRewritePattern<acc::SerialOp> {
50 using OpRewritePattern<acc::SerialOp>::OpRewritePattern;
51
52 LogicalResult matchAndRewrite(acc::SerialOp serialOp,
53 PatternRewriter &rewriter) const override {
54
55 const Location loc = serialOp.getLoc();
56
57 // Create a container holding the constant value of 1 for use as the
58 // num_gangs, num_workers, and vector_length attributes.
60 auto value = arith::ConstantIntOp::create(rewriter, loc, 1, 32);
61 numValues.push_back(value);
62
63 // Since num_gangs is specified as both attributes and values, create a
64 // segment attribute.
65 llvm::SmallVector<int32_t> numGangsSegments;
66 numGangsSegments.push_back(numValues.size());
67 auto gangSegmentsAttr = rewriter.getDenseI32ArrayAttr(numGangsSegments);
68
69 // Create a device_type attribute set to `none` which ensures that
70 // the parallel dimensions specification applies to the default clauses.
72 auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
73 rewriter.getContext(), mlir::acc::DeviceType::None);
74 crtDeviceTypes.push_back(crtDeviceTypeAttr);
75 auto devTypeAttr =
76 mlir::ArrayAttr::get(rewriter.getContext(), crtDeviceTypes);
77
78 LLVM_DEBUG(llvm::dbgs() << "acc.serial OP: " << serialOp << "\n");
79
80 // Create a new acc.parallel op with the same operands - except include the
81 // num_gangs, num_workers, and vector_length attributes.
82 acc::ParallelOp parOp = acc::ParallelOp::create(
83 rewriter, loc, serialOp.getAsyncOperands(),
84 serialOp.getAsyncOperandsDeviceTypeAttr(), serialOp.getAsyncOnlyAttr(),
85 serialOp.getWaitOperands(), serialOp.getWaitOperandsSegmentsAttr(),
86 serialOp.getWaitOperandsDeviceTypeAttr(),
87 serialOp.getHasWaitDevnumAttr(), serialOp.getWaitOnlyAttr(), numValues,
88 gangSegmentsAttr, devTypeAttr, numValues, devTypeAttr, numValues,
89 devTypeAttr, serialOp.getIfCond(), serialOp.getSelfCond(),
90 serialOp.getSelfAttrAttr(), serialOp.getReductionOperands(),
91 serialOp.getPrivateOperands(), serialOp.getFirstprivateOperands(),
92 serialOp.getDataClauseOperands(), serialOp.getDefaultAttrAttr(),
93 serialOp.getCombinedAttr());
94
95 parOp.getRegion().takeBody(serialOp.getRegion());
96
97 LLVM_DEBUG(llvm::dbgs() << "acc.parallel OP: " << parOp << "\n");
98 rewriter.replaceOp(serialOp, parOp);
99
100 return success();
101 }
102};
103
104class ACCLegalizeSerial
105 : public mlir::acc::impl::ACCLegalizeSerialBase<ACCLegalizeSerial> {
106public:
107 using ACCLegalizeSerialBase<ACCLegalizeSerial>::ACCLegalizeSerialBase;
108 void runOnOperation() override {
109 func::FuncOp funcOp = getOperation();
110 MLIRContext *context = funcOp.getContext();
112 patterns.insert<ACCSerialOpConversion>(context);
113 (void)applyPatternsGreedily(funcOp, std::move(patterns));
114 }
115};
116
117} // namespace
return success()
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:163
MLIRContext * getContext() const
Definition Builders.h:56
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition ArithOps.cpp:258
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...