MLIR  20.0.0git
Interchange.cpp
Go to the documentation of this file.
1 //===- Interchange.cpp - Linalg interchange transformation ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg interchange transformation.
10 //
11 //===----------------------------------------------------------------------===//
12 
20 #include "mlir/IR/AffineExpr.h"
21 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Pass/Pass.h"
24 #include "mlir/Support/LLVM.h"
25 #include "llvm/ADT/ScopeExit.h"
26 #include "llvm/Support/Debug.h"
27 #include "llvm/Support/raw_ostream.h"
28 #include <type_traits>
29 
30 #define DEBUG_TYPE "linalg-interchange"
31 
32 using namespace mlir;
33 using namespace mlir::linalg;
34 
35 static LogicalResult
37  ArrayRef<unsigned> interchangeVector) {
38  // Interchange vector must be non-empty and match the number of loops.
39  if (interchangeVector.empty() ||
40  genericOp.getNumLoops() != interchangeVector.size())
41  return failure();
42  // Permutation map must be invertible.
43  if (!inversePermutation(AffineMap::getPermutationMap(interchangeVector,
44  genericOp.getContext())))
45  return failure();
46  return success();
47 }
48 
49 FailureOr<GenericOp>
50 mlir::linalg::interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp,
51  ArrayRef<unsigned> interchangeVector) {
52  if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
53  return rewriter.notifyMatchFailure(genericOp, "preconditions not met");
54 
55  // 1. Compute the inverse permutation map, it must be non-null since the
56  // preconditions are satisfied.
57  MLIRContext *context = genericOp.getContext();
58  AffineMap permutationMap = inversePermutation(
59  AffineMap::getPermutationMap(interchangeVector, context));
60  assert(permutationMap && "unexpected null map");
61 
62  // Start a guarded inplace update.
63  rewriter.startOpModification(genericOp);
64  auto guard = llvm::make_scope_exit(
65  [&]() { rewriter.finalizeOpModification(genericOp); });
66 
67  // 2. Compute the interchanged indexing maps.
68  SmallVector<AffineMap> newIndexingMaps;
69  for (OpOperand &opOperand : genericOp->getOpOperands()) {
70  AffineMap m = genericOp.getMatchingIndexingMap(&opOperand);
71  if (!permutationMap.isEmpty())
72  m = m.compose(permutationMap);
73  newIndexingMaps.push_back(m);
74  }
75  genericOp.setIndexingMapsAttr(
76  rewriter.getAffineMapArrayAttr(newIndexingMaps));
77 
78  // 3. Compute the interchanged iterator types.
79  ArrayRef<Attribute> itTypes = genericOp.getIteratorTypes().getValue();
80  SmallVector<Attribute> itTypesVector;
81  llvm::append_range(itTypesVector, itTypes);
82  SmallVector<int64_t> permutation(interchangeVector);
83  applyPermutationToVector(itTypesVector, permutation);
84  genericOp.setIteratorTypesAttr(rewriter.getArrayAttr(itTypesVector));
85 
86  // 4. Transform the index operations by applying the permutation map.
87  if (genericOp.hasIndexSemantics()) {
88  OpBuilder::InsertionGuard guard(rewriter);
89  for (IndexOp indexOp :
90  llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) {
91  rewriter.setInsertionPoint(indexOp);
92  SmallVector<Value> allIndices;
93  allIndices.reserve(genericOp.getNumLoops());
94  llvm::transform(llvm::seq<uint64_t>(0, genericOp.getNumLoops()),
95  std::back_inserter(allIndices), [&](uint64_t dim) {
96  return rewriter.create<IndexOp>(indexOp->getLoc(), dim);
97  });
98  rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(
99  indexOp, permutationMap.getSubMap(indexOp.getDim()), allIndices);
100  }
101  }
102 
103  return genericOp;
104 }
static LogicalResult interchangeGenericOpPrecondition(GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Definition: Interchange.cpp:36
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
Definition: AffineMap.h:46
bool isEmpty() const
Returns true if this affine map is an empty map, i.e., () -> ().
Definition: AffineMap.cpp:369
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:264
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
Definition: AffineMap.cpp:631
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:289
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:341
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:354
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:404
This class represents an operand of an operation.
Definition: Value.h:267
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:614
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and indexing_maps dimensions and adapt the index accesses of op.
Definition: Interchange.cpp:50
Include the generated interface declarations.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particular domain dimension is selected.
Definition: AffineMap.cpp:768
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.