24#include "llvm/ADT/DenseMap.h"
25#include "llvm/Support/Debug.h"
29#define GEN_PASS_DEF_AFFINEPIPELINEDATATRANSFER
30#include "mlir/Dialect/Affine/Transforms/Passes.h.inc"
34#define DEBUG_TYPE "affine-pipeline-data-transfer"
40struct PipelineDataTransfer
41 :
public affine::impl::AffinePipelineDataTransferBase<
42 PipelineDataTransfer> {
43 void runOnOperation()
override;
44 void runOnAffineForOp(AffineForOp forOp);
46 std::vector<AffineForOp> forOps;
53std::unique_ptr<OperationPass<func::FuncOp>>
55 return std::make_unique<PipelineDataTransfer>();
62 assert((isa<AffineDmaStartOp, AffineDmaWaitOp>(dmaOp)));
63 if (
auto dmaStartOp = dyn_cast<AffineDmaStartOp>(dmaOp)) {
64 return dmaStartOp.getTagMemRefOperandIndex();
76 auto *forBody = forOp.getBody();
77 OpBuilder bInner(forBody, forBody->begin());
80 auto doubleShape = [&](MemRefType oldMemRefType) -> MemRefType {
85 llvm::copy(oldShape, newShape.begin() + 1);
89 auto oldMemRefType = cast<MemRefType>(oldMemRef.
getType());
90 auto newMemRefType = doubleShape(oldMemRefType);
96 for (
const auto &dim : llvm::enumerate(oldMemRefType.getShape())) {
97 if (dim.value() == ShapedType::kDynamic)
98 allocOperands.push_back(bOuter.
createOrFold<memref::DimOp>(
99 forOp.getLoc(), oldMemRef, dim.index()));
103 Value newMemRef = memref::AllocOp::create(bOuter, forOp.getLoc(),
104 newMemRefType, allocOperands);
108 int64_t step = forOp.getStepAsInt();
111 auto ivModTwoOp = AffineApplyOp::create(bInner, forOp.getLoc(), modTwoMap,
112 forOp.getInductionVar());
116 auto userFilterFn = [&](
Operation *user) {
117 auto domInfo = std::make_unique<DominanceInfo>(
118 forOp->getParentOfType<FunctionOpInterface>());
119 return domInfo->dominates(&*forOp.getBody()->begin(), user);
121 if (failed(replaceAllMemRefUsesWith(oldMemRef, newMemRef,
125 {}, userFilterFn))) {
127 forOp.emitError(
"memref replacement for double buffering failed"));
133 memref::DeallocOp::create(bOuter, forOp.getLoc(), newMemRef);
139void PipelineDataTransfer::runOnOperation() {
146 getOperation().walk([&](AffineForOp forOp) { forOps.push_back(forOp); });
147 for (
auto forOp : forOps)
148 runOnAffineForOp(forOp);
152static bool checkTagMatch(AffineDmaStartOp startOp, AffineDmaWaitOp waitOp) {
153 if (startOp.getTagMemRef() != waitOp.getTagMemRef())
155 auto startIndices = startOp.getTagIndices();
156 auto waitIndices = waitOp.getTagIndices();
159 for (
auto it = startIndices.begin(), wIt = waitIndices.begin(),
160 e = startIndices.end();
161 it != e; ++it, ++wIt) {
177 SmallVectorImpl<std::pair<Operation *, Operation *>> &startWaitPairs) {
181 for (
auto &op : *forOp.getBody()) {
182 auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op);
183 if (dmaStartOp && dmaStartOp.isSrcMemorySpaceFaster())
184 outgoingDmaOps.push_back(dmaStartOp);
188 for (
auto &op : *forOp.getBody()) {
190 if (isa<AffineDmaWaitOp>(op)) {
191 dmaFinishInsts.push_back(&op);
194 auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op);
200 if (!dmaStartOp.isDestMemorySpaceFaster())
206 auto *it = outgoingDmaOps.begin();
207 for (; it != outgoingDmaOps.end(); ++it) {
208 if (it->getDstMemRef() == dmaStartOp.getSrcMemRef())
211 if (it != outgoingDmaOps.end())
215 auto memref = dmaStartOp.getOperand(dmaStartOp.getFasterMemPos());
216 bool escapingUses =
false;
217 for (
auto *user :
memref.getUsers()) {
219 if (isa<memref::DeallocOp>(user))
221 if (!forOp.getBody()->findAncestorOpInBlock(*user)) {
222 LLVM_DEBUG(llvm::dbgs()
223 <<
"can't pipeline: buffer is live out of loop\n";);
229 dmaStartInsts.push_back(&op);
233 for (
auto *dmaStartOp : dmaStartInsts) {
234 for (
auto *dmaFinishOp : dmaFinishInsts) {
236 cast<AffineDmaWaitOp>(dmaFinishOp))) {
237 startWaitPairs.push_back({dmaStartOp, dmaFinishOp});
247void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) {
248 if (!forOp.getStaticTripCount()) {
249 LLVM_DEBUG(forOp.emitRemark(
"won't pipeline due to unknown trip count"));
253 SmallVector<std::pair<Operation *, Operation *>, 4> startWaitPairs;
256 if (startWaitPairs.empty()) {
257 LLVM_DEBUG(forOp.emitRemark(
"No dma start/finish pairs\n"));
269 for (
auto &pair : startWaitPairs) {
270 auto *dmaStartOp = pair.first;
271 Value oldMemRef = dmaStartOp->getOperand(
272 cast<AffineDmaStartOp>(dmaStartOp).getFasterMemPos());
276 LLVM_DEBUG(llvm::dbgs()
277 <<
"double buffering failed for" << dmaStartOp <<
"\n";);
292 dyn_cast<memref::DeallocOp>(*oldMemRef.
user_begin())) {
301 for (
auto &pair : startWaitPairs) {
302 auto *dmaFinishOp = pair.second;
303 Value oldTagMemRef = dmaFinishOp->getOperand(
getTagMemRefPos(*dmaFinishOp));
305 LLVM_DEBUG(llvm::dbgs() <<
"tag double buffering failed\n";);
315 dyn_cast<memref::DeallocOp>(*oldTagMemRef.
user_begin())) {
324 startWaitPairs.clear();
329 for (
auto &pair : startWaitPairs) {
330 auto *dmaStartOp = pair.first;
331 assert(isa<AffineDmaStartOp>(dmaStartOp));
332 instShiftMap[dmaStartOp] = 0;
334 SmallVector<AffineApplyOp, 4> sliceOps;
335 affine::createAffineComputationSlice(dmaStartOp, &sliceOps);
336 if (!sliceOps.empty()) {
337 for (
auto sliceOp : sliceOps) {
338 instShiftMap[sliceOp.getOperation()] = 0;
343 SmallVector<Operation *, 4> affineApplyInsts;
344 SmallVector<Value, 4> operands(dmaStartOp->getOperands());
346 for (
auto *op : affineApplyInsts) {
347 instShiftMap[op] = 0;
352 for (
auto &op : forOp.getBody()->without_terminator())
353 instShiftMap.try_emplace(&op, 1);
356 SmallVector<uint64_t, 8> shifts(forOp.getBody()->getOperations().size());
358 for (
auto &op : forOp.getBody()->without_terminator()) {
359 assert(instShiftMap.contains(&op));
360 shifts[s++] = instShiftMap[&op];
365 op.setAttr(
"shift",
b.getI64IntegerAttr(shifts[s - 1]));
371 LLVM_DEBUG(llvm::dbgs() <<
"Shifts invalid - unexpected\n";);
376 LLVM_DEBUG(llvm::dbgs() <<
"op body skewing failed - unexpected\n";);
static unsigned getTagMemRefPos(Operation &dmaOp)
static void findMatchingStartFinishInsts(AffineForOp forOp, SmallVectorImpl< std::pair< Operation *, Operation * > > &startWaitPairs)
static bool checkTagMatch(AffineDmaStartOp startOp, AffineDmaWaitOp waitOp)
static bool doubleBuffer(Value oldMemRef, AffineForOp forOp)
Doubles the buffer of the supplied memref on the specified 'affine.for' operation by adding a leading dimension of size two to the memref.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineExpr getAffineDimExpr(unsigned position)
This is a builder type that keeps local references to arguments.
Builder & setShape(ArrayRef< int64_t > newShape)
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Operation is the basic unit of execution within MLIR.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
user_iterator user_begin() const
bool hasOneUse() const
Returns true if this value has exactly one use.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
LogicalResult affineForOpBodySkew(AffineForOp forOp, ArrayRef< uint64_t > shifts, bool unrollPrologueEpilogue=false)
Skew the operations in an affine.for's body with the specified operation-wise shifts.
void getReachableAffineApplyOps(ArrayRef< Value > operands, SmallVectorImpl< Operation * > &affineApplyOps)
Returns in affineApplyOps, the sequence of those AffineApplyOp Operations that are reachable via a search starting from operands and ending at those operands that are not the result of an AffineApplyOp.
std::unique_ptr< OperationPass< func::FuncOp > > createPipelineDataTransferPass()
Creates a pass to pipeline explicit movement of data across levels of the memory hierarchy.
bool isOpwiseShiftValid(AffineForOp forOp, ArrayRef< uint64_t > shifts)
Checks whether SSA dominance would be violated if a for op's body operations are shifted by the specified shifts.
Include the generated interface declarations.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap