23 #include "llvm/Support/Debug.h"
26 #define GEN_PASS_DEF_LINALGSPECIALIZEGENERICOPSPASS
27 #include "mlir/Dialect/Linalg/Passes.h.inc"
30 #define DEBUG_TYPE "linalg-specialization"
32 #define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP) \
33 (rewriter.replaceOpWithNewOp<NEWOP>( \
35 ValueRange{genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 1 : 0], \
36 genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 0 : 1]}, \
37 ValueRange{genericOp.getDpsInits()[0]}))
// Replaces `genericOp` with a newly created named op `NEWOP` and evaluates to
// the value returned by `rewriter.replaceOpWithNewOp<NEWOP>(...)`.
// The new op receives exactly one input, the generic's first DPS input, and
// exactly one init, the generic's first DPS init.
//
// Unhygienic by design: the expansion site must have `rewriter` and
// `genericOp` in scope. Only operand 0 of each list is forwarded, so callers
// presumably have already checked that the generic is a single-input /
// single-init unary op (TODO confirm at each use site).
#define REPLACE_UNARY_OP(NEWOP)                                                \
  (rewriter.replaceOpWithNewOp<NEWOP>(                                         \
      genericOp, ValueRange{genericOp.getDpsInputs()[0]},                      \
      ValueRange{genericOp.getDpsInits()[0]}))
59 Block *body = genericOp.getBody();
66 "binary op uses just one block arg");
97 enum class IndexMatchResult {
111 static IndexMatchResult matchOperandMap(
AffineMap map,
unsigned rowDimIdx,
112 unsigned expectedPosOfRowDim,
113 unsigned expectedPosOfColDim) {
115 auto exprOfRowDim = map.
getResults()[rowDimIdx];
116 auto exprOfColDim = map.
getResults()[rowDimIdx + 1];
121 return IndexMatchResult::Mismatch;
123 auto posRowDim = cast<AffineDimExpr>(exprOfRowDim).getPosition();
124 auto posColDim = cast<AffineDimExpr>(exprOfColDim).getPosition();
126 if (expectedPosOfRowDim == posRowDim && expectedPosOfColDim == posColDim)
127 return IndexMatchResult::Match;
129 if (expectedPosOfRowDim == posColDim && expectedPosOfColDim == posRowDim)
130 return IndexMatchResult::Transposed;
132 return IndexMatchResult::Mismatch;
139 template <
typename NamedOpTy>
140 static LinalgOp replaceWithMatmulVariant(
RewriterBase &rewriter, GenericOp op) {
142 op,
ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
148 static FailureOr<LinalgOp> specializeLinalgContractions(
RewriterBase &rewriter,
149 GenericOp genericOp) {
150 if (genericOp.getNumDpsInputs() != 2 || genericOp.getNumDpsInits() != 1)
154 auto mapRange = genericOp.getIndexingMapsArray();
155 if (llvm::any_of(mapRange,
156 [](
AffineMap m) {
return !m.isProjectedPermutation(); }))
182 if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)
187 if ((isa<arith::MulFOp>(first) && isa<arith::AddFOp>(second)) ||
188 (isa<arith::MulIOp>(first) && isa<arith::AddIOp>(second)) ||
189 (isa<complex::MulOp>(first) && isa<complex::AddOp>(second)))
196 auto indexingMaps = genericOp.getIndexingMapsArray();
197 if (llvm::any_of(indexingMaps, [&dims](
AffineMap m) {
198 return m.getResults().size() !=
199 dims.batch.size() + 2 ;
203 auto numOfBatchDims = dims.batch.size();
204 if (indexingMaps[0].getNumDims() != numOfBatchDims + 3)
207 if (numOfBatchDims) {
211 if (llvm::any_of(indexingMaps, [numOfBatchDims](
AffineMap m) {
212 for (
unsigned i = 0; i < numOfBatchDims; ++i) {
213 auto expr = m.getResults()[i];
215 cast<AffineDimExpr>(expr).getPosition() != i)
224 matchOperandMap(indexingMaps[0], numOfBatchDims, dims.m[0], dims.k[0]);
226 matchOperandMap(indexingMaps[1], numOfBatchDims, dims.k[0], dims.n[0]);
228 matchOperandMap(indexingMaps[2], numOfBatchDims, dims.m[0], dims.n[0]);
230 if (llvm::is_contained({a, b, c}, IndexMatchResult::Mismatch))
233 if (c != IndexMatchResult::Match ||
234 (a == IndexMatchResult::Transposed && b == IndexMatchResult::Transposed))
238 if (numOfBatchDims) {
239 if (a == IndexMatchResult::Transposed)
240 return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(rewriter,
242 if (b == IndexMatchResult::Transposed)
243 return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(rewriter,
245 return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp);
248 if (a == IndexMatchResult::Transposed)
249 return replaceWithMatmulVariant<MatmulTransposeAOp>(rewriter, genericOp);
250 if (b == IndexMatchResult::Transposed)
251 return replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, genericOp);
252 return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
261 GenericOp genericOp) {
264 genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
270 genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
275 Operation *op = &genericOp.getBody()->front();
276 if (isa<math::ExpOp>(op)) {
284 Operation *op = &genericOp.getBody()->front();
285 if (isa<arith::AddFOp>(op)) {
289 if (isa<arith::SubFOp>(op)) {
293 if (isa<arith::MulFOp>(op)) {
297 if (isa<arith::DivFOp>(op)) {
304 return specializeLinalgContractions(rewriter, genericOp);
310 struct LinalgSpecializeGenericOpsPass
311 :
public impl::LinalgSpecializeGenericOpsPassBase<
312 LinalgSpecializeGenericOpsPass> {
314 using impl::LinalgSpecializeGenericOpsPassBase<
315 LinalgSpecializeGenericOpsPass>::LinalgSpecializeGenericOpsPassBase;
316 void runOnOperation()
override;
320 void LinalgSpecializeGenericOpsPass::runOnOperation() {
static MLIRContext * getContext(OpFoldResult val)
#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP)
static bool areBinOpsSwapped(GenericOp genericOp)
#define REPLACE_UNARY_OP(NEWOP)
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
ArrayRef< AffineExpr > getResults() const
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
IRValueT get() const
Return the current value being used by this operand.
Operation is the basic unit of execution within MLIR.
OpOperand & getOpOperand(unsigned idx)
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp)
Checks whether a given genericOp is semantically equivalent to a single linalg elementwise unary op.
bool isaCopyOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a linalg.copyOp.
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp)
Create a namedOp from the given GenericOp and replace the GenericOp.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
void populateLinalgGenericOpsSpecializationPatterns(RewritePatternSet &patterns)
Populates patterns with patterns to convert linalg.generic ops to named ops where possible.
std::optional< Value > isaFillOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.fill.
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a single linalg elementwise binary op e....
Include the generated interface declarations.
@ DimId
Dimensional identifier.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...