MLIR  19.0.0git
Interchange.cpp
Go to the documentation of this file.
1 //===- Interchange.cpp - Linalg interchange transformation ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg interchange transformation.
10 //
11 //===----------------------------------------------------------------------===//
12 
20 #include "mlir/IR/AffineExpr.h"
21 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Pass/Pass.h"
24 #include "mlir/Support/LLVM.h"
25 #include "llvm/ADT/ScopeExit.h"
26 #include "llvm/Support/Debug.h"
27 #include "llvm/Support/raw_ostream.h"
28 #include <type_traits>
29 
30 #define DEBUG_TYPE "linalg-interchange"
31 
32 using namespace mlir;
33 using namespace mlir::linalg;
34 
35 static LogicalResult
37  ArrayRef<unsigned> interchangeVector) {
38  // Interchange vector must be non-empty and match the number of loops.
39  if (interchangeVector.empty() ||
40  genericOp.getNumLoops() != interchangeVector.size())
41  return failure();
42  // Permutation map must be invertible.
43  if (!inversePermutation(AffineMap::getPermutationMap(interchangeVector,
44  genericOp.getContext())))
45  return failure();
46  return success();
47 }
48 
50 mlir::linalg::interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp,
51  ArrayRef<unsigned> interchangeVector) {
52  if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
53  return rewriter.notifyMatchFailure(genericOp, "preconditions not met");
54 
55  // 1. Compute the inverse permutation map, it must be non-null since the
56  // preconditions are satisfied.
57  MLIRContext *context = genericOp.getContext();
58  AffineMap permutationMap = inversePermutation(
59  AffineMap::getPermutationMap(interchangeVector, context));
60  assert(permutationMap && "unexpected null map");
61 
62  // Start a guarded inplace update.
63  rewriter.startOpModification(genericOp);
64  auto guard = llvm::make_scope_exit(
65  [&]() { rewriter.finalizeOpModification(genericOp); });
66 
67  // 2. Compute the interchanged indexing maps.
68  SmallVector<AffineMap> newIndexingMaps;
69  for (OpOperand &opOperand : genericOp->getOpOperands()) {
70  AffineMap m = genericOp.getMatchingIndexingMap(&opOperand);
71  if (!permutationMap.isEmpty())
72  m = m.compose(permutationMap);
73  newIndexingMaps.push_back(m);
74  }
75  genericOp.setIndexingMapsAttr(
76  rewriter.getAffineMapArrayAttr(newIndexingMaps));
77 
78  // 3. Compute the interchanged iterator types.
79  ArrayRef<Attribute> itTypes = genericOp.getIteratorTypes().getValue();
80  SmallVector<Attribute> itTypesVector;
81  llvm::append_range(itTypesVector, itTypes);
82  SmallVector<int64_t> permutation(interchangeVector.begin(),
83  interchangeVector.end());
84  applyPermutationToVector(itTypesVector, permutation);
85  genericOp.setIteratorTypesAttr(rewriter.getArrayAttr(itTypesVector));
86 
87  // 4. Transform the index operations by applying the permutation map.
88  if (genericOp.hasIndexSemantics()) {
89  OpBuilder::InsertionGuard guard(rewriter);
90  for (IndexOp indexOp :
91  llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) {
92  rewriter.setInsertionPoint(indexOp);
93  SmallVector<Value> allIndices;
94  allIndices.reserve(genericOp.getNumLoops());
95  llvm::transform(llvm::seq<uint64_t>(0, genericOp.getNumLoops()),
96  std::back_inserter(allIndices), [&](uint64_t dim) {
97  return rewriter.create<IndexOp>(indexOp->getLoc(), dim);
98  });
99  rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(
100  indexOp, permutationMap.getSubMap(indexOp.getDim()), allIndices);
101  }
102  }
103 
104  return genericOp;
105 }
static LogicalResult interchangeGenericOpPrecondition(GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Definition: Interchange.cpp:36
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:47
bool isEmpty() const
Returns true if this affine map is an empty map, i.e., () -> ().
Definition: AffineMap.cpp:353
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:248
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
Definition: AffineMap.cpp:615
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:540
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:273
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:325
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
This class represents an operand of an operation.
Definition: Value.h:267
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:614
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchanges the iterator_types and indexing_maps dimensions and adapts the index accesses of op.
Definition: Interchange.cpp:50
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:753
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26