MLIR 22.0.0git
Interchange.cpp
Go to the documentation of this file.
1//===- Interchange.cpp - Linalg interchange transformation ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the linalg interchange transformation.
10//
11//===----------------------------------------------------------------------===//
12
20#include "mlir/IR/AffineExpr.h"
22#include "mlir/Support/LLVM.h"
23#include "llvm/ADT/ScopeExit.h"
24
25#define DEBUG_TYPE "linalg-interchange"
26
27using namespace mlir;
28using namespace mlir::linalg;
29
30static LogicalResult
32 ArrayRef<unsigned> interchangeVector) {
33 // Interchange vector must be non-empty and match the number of loops.
34 if (interchangeVector.empty() ||
35 genericOp.getNumLoops() != interchangeVector.size())
36 return failure();
37 // Permutation map must be invertible.
39 genericOp.getContext())))
40 return failure();
41 return success();
42}
43
44FailureOr<GenericOp>
45mlir::linalg::interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp,
46 ArrayRef<unsigned> interchangeVector) {
47 if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
48 return rewriter.notifyMatchFailure(genericOp, "preconditions not met");
49
50 // 1. Compute the inverse permutation map, it must be non-null since the
51 // preconditions are satisfied.
52 MLIRContext *context = genericOp.getContext();
53 AffineMap permutationMap = inversePermutation(
54 AffineMap::getPermutationMap(interchangeVector, context));
55 assert(permutationMap && "unexpected null map");
56
57 // Start a guarded inplace update.
58 rewriter.startOpModification(genericOp);
59 auto guard = llvm::make_scope_exit(
60 [&]() { rewriter.finalizeOpModification(genericOp); });
61
62 // 2. Compute the interchanged indexing maps.
63 SmallVector<AffineMap> newIndexingMaps;
64 for (OpOperand &opOperand : genericOp->getOpOperands()) {
65 AffineMap m = genericOp.getMatchingIndexingMap(&opOperand);
66 if (!permutationMap.isEmpty())
67 m = m.compose(permutationMap);
68 newIndexingMaps.push_back(m);
69 }
70 genericOp.setIndexingMapsAttr(
71 rewriter.getAffineMapArrayAttr(newIndexingMaps));
72
73 // 3. Compute the interchanged iterator types.
74 ArrayRef<Attribute> itTypes = genericOp.getIteratorTypes().getValue();
75 SmallVector<Attribute> itTypesVector;
76 llvm::append_range(itTypesVector, itTypes);
77 SmallVector<int64_t> permutation(interchangeVector);
78 applyPermutationToVector(itTypesVector, permutation);
79 genericOp.setIteratorTypesAttr(rewriter.getArrayAttr(itTypesVector));
80
81 // 4. Transform the index operations by applying the permutation map.
82 if (genericOp.hasIndexSemantics()) {
83 OpBuilder::InsertionGuard guard(rewriter);
84 for (IndexOp indexOp :
85 llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) {
86 rewriter.setInsertionPoint(indexOp);
87 SmallVector<Value> allIndices;
88 allIndices.reserve(genericOp.getNumLoops());
89 llvm::transform(llvm::seq<uint64_t>(0, genericOp.getNumLoops()),
90 std::back_inserter(allIndices), [&](uint64_t dim) {
91 return IndexOp::create(rewriter, indexOp->getLoc(),
92 dim);
93 });
94 rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(
95 indexOp, permutationMap.getSubMap(indexOp.getDim()), allIndices);
96 }
97 }
98
99 return genericOp;
100}
return success()
static LogicalResult interchangeGenericOpPrecondition(GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition AffineMap.h:46
bool isEmpty() const
Returns true if this affine map is an empty map, i.e., () -> ().
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:266
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition Builders.cpp:318
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
This class represents an operand of an operation.
Definition Value.h:257
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchanges the iterator_types and indexing_maps dimensions and adapts the index accesses of op.
Include the generated interface declarations.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.