24#define GEN_PASS_DEF_LINALGSPECIALIZEGENERICOPSPASS
25#include "mlir/Dialect/Linalg/Passes.h.inc"
28#define DEBUG_TYPE "linalg-specialization"
30#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP) \
31 (rewriter.replaceOpWithNewOp<NEWOP>( \
33 ValueRange{genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 1 : 0], \
34 genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 0 : 1]}, \
35 ValueRange{genericOp.getDpsInits()[0]}))
// REPLACE_UNARY_OP(NEWOP): expands to an expression that calls
// `rewriter.replaceOpWithNewOp<NEWOP>` on `genericOp`. The new NEWOP receives:
//   - a one-element input range holding the generic's first DPS input, and
//   - a one-element init range holding the generic's first DPS init.
// The expression evaluates to whatever replaceOpWithNewOp returns (the new op).
// The macro is unhygienic: the names `rewriter` and `genericOp` must be in
// scope at the expansion site.
// NOTE(review): presumably only expanded after the generic has been matched as
// a single-input elementwise unary op (e.g. math.exp) — confirm at call sites.
37#define REPLACE_UNARY_OP(NEWOP) \
38 (rewriter.replaceOpWithNewOp<NEWOP>(genericOp, \
39 ValueRange{genericOp.getDpsInputs()[0]}, \
40 ValueRange{genericOp.getDpsInits()[0]}))
57 Block *body = genericOp.getBody();
64 "binary op uses just one block arg");
95enum class IndexMatchResult {
109static IndexMatchResult matchOperandMap(
AffineMap map,
unsigned rowDimIdx,
110 unsigned expectedPosOfRowDim,
111 unsigned expectedPosOfColDim) {
113 auto exprOfRowDim = map.
getResults()[rowDimIdx];
114 auto exprOfColDim = map.
getResults()[rowDimIdx + 1];
119 return IndexMatchResult::Mismatch;
121 auto posRowDim = cast<AffineDimExpr>(exprOfRowDim).getPosition();
122 auto posColDim = cast<AffineDimExpr>(exprOfColDim).getPosition();
124 if (expectedPosOfRowDim == posRowDim && expectedPosOfColDim == posColDim)
125 return IndexMatchResult::Match;
127 if (expectedPosOfRowDim == posColDim && expectedPosOfColDim == posRowDim)
128 return IndexMatchResult::Transposed;
130 return IndexMatchResult::Mismatch;
137template <
typename NamedOpTy>
138static LinalgOp replaceWithMatmulVariant(
RewriterBase &rewriter, GenericOp op) {
140 op,
ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
146static FailureOr<LinalgOp> specializeLinalgContractions(
RewriterBase &rewriter,
147 GenericOp genericOp) {
148 if (genericOp.getNumDpsInputs() != 2 || genericOp.getNumDpsInits() != 1)
152 auto mapRange = genericOp.getIndexingMapsArray();
153 if (llvm::any_of(mapRange,
180 if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)
185 return (isa<arith::MulFOp>(first) && isa<arith::AddFOp>(second)) ||
186 (isa<arith::MulIOp>(first) && isa<arith::AddIOp>(second)) ||
187 (isa<complex::MulOp>(first) && isa<complex::AddOp>(second));
192 auto indexingMaps = genericOp.getIndexingMapsArray();
193 if (llvm::any_of(indexingMaps, [&dims](
AffineMap m) {
195 dims.batch.size() + 2 ;
199 auto numOfBatchDims = dims.batch.size();
200 if (indexingMaps[0].getNumDims() != numOfBatchDims + 3)
203 if (numOfBatchDims) {
207 if (llvm::any_of(indexingMaps, [numOfBatchDims](
AffineMap m) {
208 for (
unsigned i = 0; i < numOfBatchDims; ++i) {
211 cast<AffineDimExpr>(expr).getPosition() != i)
220 matchOperandMap(indexingMaps[0], numOfBatchDims, dims.m[0], dims.k[0]);
222 matchOperandMap(indexingMaps[1], numOfBatchDims, dims.k[0], dims.n[0]);
224 matchOperandMap(indexingMaps[2], numOfBatchDims, dims.m[0], dims.n[0]);
226 if (llvm::is_contained({a,
b, c}, IndexMatchResult::Mismatch))
229 if (c != IndexMatchResult::Match ||
230 (a == IndexMatchResult::Transposed &&
b == IndexMatchResult::Transposed))
234 if (numOfBatchDims) {
235 return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp);
237 return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
242template <
typename ConvOpTy>
243static FailureOr<LinalgOp>
244specializeToConvOp(
RewriterBase &rewriter, GenericOp genericOp,
253 if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp> ||
254 std::is_same_v<ConvOpTy, linalg::Conv2DOp> ||
255 std::is_same_v<ConvOpTy, linalg::Conv3DOp>) {
262 genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
268static FailureOr<LinalgOp> specializeLinalgConvolutions(
RewriterBase &rewriter,
269 GenericOp genericOp) {
271#define CONV_OP_SPECIALIZER(ConvOpTy) \
272 if (isaConvolutionOpOfType<ConvOpTy>(genericOp, &dilations, &strides)) \
273 return specializeToConvOp<ConvOpTy>(rewriter, genericOp, dilations, \
331#undef CONV_OP_SPECIALIZER
341 GenericOp genericOp) {
345 genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
353 genericOp, *fillValue, genericOp.getDpsInits()[0]);
358 std::optional<SmallVector<int64_t>> equivalentToBroadcast =
360 if (equivalentToBroadcast) {
361 auto dims = *equivalentToBroadcast;
363 genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
369 std::optional<SmallVector<int64_t>> equivalentToTranspose =
371 if (equivalentToTranspose) {
372 auto permutation = *equivalentToTranspose;
374 genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
381 Operation *op = &genericOp.getBody()->front();
382 if (isa<math::ExpOp>(op)) {
391 Operation *op = &genericOp.getBody()->front();
392 if (isa<arith::AddFOp>(op)) {
396 if (isa<arith::SubFOp>(op)) {
400 if (isa<arith::MulFOp>(op)) {
404 if (isa<arith::DivFOp>(op)) {
412 return specializeLinalgContractions(rewriter, genericOp);
417 return specializeLinalgConvolutions(rewriter, genericOp);
423struct LinalgSpecializeGenericOpsPass
425 LinalgSpecializeGenericOpsPass> {
428 LinalgSpecializeGenericOpsPass>::LinalgSpecializeGenericOpsPassBase;
429 void runOnOperation()
override;
433void LinalgSpecializeGenericOpsPass::runOnOperation() {
#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP)
#define CONV_OP_SPECIALIZER(ConvOpTy)
static bool areBinOpsSwapped(GenericOp genericOp)
#define REPLACE_UNARY_OP(NEWOP)
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
ArrayRef< AffineExpr > getResults() const
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
DenseIntElementsAttr getI64TensorAttr(ArrayRef< int64_t > values)
IRValueT get() const
Return the current value being used by this operand.
Operation is the basic unit of execution within MLIR.
OpOperand & getOpOperand(unsigned idx)
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
This class provides an abstraction over the different types of ranges over Values.
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
std::optional< SmallVector< int64_t > > isaTransposeOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.transpose.
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp)
Checks whether a given genericOp is semantically equivalent to a single linalg elementwise unary op.
bool isaCopyOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a linalg.copyOp.
void populateDecomposeProjectedPermutationPatterns(RewritePatternSet &patterns)
Add patterns to make explicit broadcasts and transforms in the input operands of a genericOp.
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp)
Create a namedOp from the given GenericOp and replace the GenericOp.
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaBroadcastOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.broadcast.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
void populateLinalgGenericOpsSpecializationPatterns(RewritePatternSet &patterns)
Populates patterns with patterns to convert linalg.generic ops to named ops where possible.
std::optional< Value > isaFillOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.fill.
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a single linalg elementwise binary op e....
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
@ DimId
Dimensional identifier.
const FrozenRewritePatternSet & patterns