#define GEN_PASS_DEF_LINALGSPECIALIZEGENERICOPSPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"

#define DEBUG_TYPE "linalg-specialization"
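
// Helper macros that rewrite an elementwise linalg.generic into the named
// binary/unary op NEWOP. OPERANDS_SWAP selects the operand order for bodies
// whose single op consumes the block arguments in reverse.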
#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP)                                \
  (rewriter.replaceOpWithNewOp<NEWOP>(                                         \
      genericOp,                                                               \
      ValueRange{genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 1 : 0],            \
                 genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 0 : 1]},           \
      ValueRange{genericOp.getDpsInits()[0]}))
#define REPLACE_UNARY_OP(NEWOP)                                                \
  (rewriter.replaceOpWithNewOp<NEWOP>(genericOp,                               \
                                      ValueRange{genericOp.getDpsInputs()[0]}, \
                                      ValueRange{genericOp.getDpsInits()[0]}))
using namespace mlir;
using namespace mlir::linalg;

// Given an elementwise single-binary-op linalg.generic, checks whether the op
// reads its block arguments in swapped order, i.e. distinguishes a body
// computing sub(%a, %b) from one computing sub(%b, %a).
static bool areBinOpsSwapped(GenericOp genericOp) {
  Block *body = genericOp.getBody();
  Operation *op = &body->front();
  bool swapped = false;
  if (op->getOpOperand(0).get() != body->getArgument(0)) {
    swapped = true;
    assert(op->getOpOperand(0).get() == body->getArgument(1) &&
           op->getOpOperand(1).get() == body->getArgument(0) &&
           "binary op uses just one block arg");
  }
  return swapped;
}
enum class IndexMatchResult {
  Match = 0,  // dims appear in the expected (row, col) order.
  Transposed, // dims appear as (col, row), i.e. swapped.
  Mismatch    // neither of the above.
};
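
// Checks how `map` accesses a 2-D matrix slice through the two consecutive
// results starting at `rowDimIdx`: expected order, transposed, or mismatch.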
static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,
                                        unsigned expectedPosOfRowDim,
                                        unsigned expectedPosOfColDim) {
  auto exprOfRowDim = map.getResults()[rowDimIdx];
  auto exprOfColDim = map.getResults()[rowDimIdx + 1];

  // Both results must be pure dimension ids.
  if (exprOfRowDim.getKind() != AffineExprKind::DimId ||
      exprOfColDim.getKind() != AffineExprKind::DimId)
    return IndexMatchResult::Mismatch;

  auto posRowDim = cast<AffineDimExpr>(exprOfRowDim).getPosition();
  auto posColDim = cast<AffineDimExpr>(exprOfColDim).getPosition();

  if (expectedPosOfRowDim == posRowDim && expectedPosOfColDim == posColDim)
    return IndexMatchResult::Match;

  if (expectedPosOfRowDim == posColDim && expectedPosOfColDim == posRowDim)
    return IndexMatchResult::Transposed;

  return IndexMatchResult::Mismatch;
}
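
// Replaces genericOp with the matmul variant NamedOpTy. All variants share
// the same ins/outs signature, so one template stamps each of them out; an
// explicit `cast` attribute is attached only for the non-default unsigned
// case.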
template <typename NamedOpTy>
static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op,
                                         std::optional<TypeFn> castTy) {
  SmallVector<NamedAttribute> castAttrVec;
  if (castTy.has_value() && *castTy == TypeFn::cast_unsigned)
    castAttrVec = {rewriter.getNamedAttr(
        "cast", TypeFnAttr::get(rewriter.getContext(), *castTy))};

  LinalgOp namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
      op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
      ValueRange{op.getDpsInits()[0]}, castAttrVec);
  return namedOp;
}
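
// Walks the casts in the contraction body and derives the TypeFn the named op
// should carry: cast_signed by default, cast_unsigned if all input casts are
// unsigned, and std::nullopt when a cast applies to the matmul output (which
// the named ops cannot express) or when signed and unsigned casts are mixed.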
static std::optional<TypeFn> getCastTypeForMatmulLikeOp(GenericOp genericOp) {
  bool foundCastForMatmulOutput = false;
  SmallVector<TypeFn> castTyFns;
  genericOp.getBody()->walk([&](CastOpInterface castOp) {
    SetVector<Operation *> forwardSlice;
    getForwardSlice(castOp, &forwardSlice);
    // A cast whose forward slice never reaches the multiply casts the
    // accumulator/output rather than an input.
    if (!llvm::any_of(forwardSlice, [](Operation *op) {
          return isa<arith::MulIOp, arith::MulFOp, complex::MulOp>(op);
        })) {
      foundCastForMatmulOutput = true;
      return WalkResult::interrupt();
    }
    if (isa<arith::ExtUIOp, arith::UIToFPOp, arith::FPToUIOp>(castOp))
      castTyFns.push_back(TypeFn::cast_unsigned);
    else if (isa<arith::ExtSIOp, arith::SIToFPOp, arith::FPToSIOp>(castOp))
      castTyFns.push_back(TypeFn::cast_signed);
    return WalkResult::advance();
  });

  if (foundCastForMatmulOutput)
    return std::nullopt;

  if (!castTyFns.empty()) {
    // All input casts must agree on signedness.
    if (!llvm::all_equal(castTyFns))
      return std::nullopt;
    return castTyFns.front();
  }
  return TypeFn::cast_signed;
}
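
// Matches a linalg.generic contraction against linalg.matmul and
// linalg.batch_matmul: exactly two inputs and one init, single m/n/k
// dimensions, identity maps on the batch dimensions, and operand maps that
// are at most transposed in their two innermost dims.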
static FailureOr<LinalgOp>
specializeLinalgContractions(RewriterBase &rewriter, GenericOp genericOp) {
  if (genericOp.getNumDpsInputs() != 2 || genericOp.getNumDpsInits() != 1)
    return failure();

  // Early exit if any indexing map is not a projected permutation.
  auto mapRange = genericOp.getIndexingMapsArray();
  if (llvm::any_of(mapRange,
                   [](AffineMap m) { return !m.isProjectedPermutation(); }))
    return failure();

  // The contraction must reduce to a single m, n and k dimension to be
  // expressible as a named matmul.
  auto res = inferContractionDims(genericOp);
  if (failed(res))
    return failure();
  auto dims = *res;
  if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)
    return failure();

  // The body must be a multiply feeding an add of the matching element kind.
  if (!mlir::linalg::detail::isContractionBody(
          *genericOp.getBody(), [](Operation *first, Operation *second) {
            return (isa<arith::MulFOp>(first) && isa<arith::AddFOp>(second)) ||
                   (isa<arith::MulIOp>(first) && isa<arith::AddIOp>(second)) ||
                   (isa<complex::MulOp>(first) && isa<complex::AddOp>(second));
          }))
    return failure();

  // Check operand ranks: batch dims plus a 2-D matrix slice per operand.
  auto indexingMaps = genericOp.getIndexingMapsArray();
  if (llvm::any_of(indexingMaps, [&dims](AffineMap m) {
        return m.getResults().size() != dims.batch.size() + 2;
      }))
    return failure();

  auto numOfBatchDims = dims.batch.size();
  if (indexingMaps[0].getNumDims() != numOfBatchDims + 3)
    return failure();

  if (numOfBatchDims) {
    // Each map must access the batch dims as the identity, since a named op
    // cannot express per-operand batch permutations.
    if (llvm::any_of(indexingMaps, [numOfBatchDims](AffineMap m) {
          for (unsigned i = 0; i < numOfBatchDims; ++i) {
            auto expr = m.getResults()[i];
            if (expr.getKind() != AffineExprKind::DimId ||
                cast<AffineDimExpr>(expr).getPosition() != i)
              return true;
          }
          return false;
        }))
      return failure();
  }

  auto a =
      matchOperandMap(indexingMaps[0], numOfBatchDims, dims.m[0], dims.k[0]);
  auto b =
      matchOperandMap(indexingMaps[1], numOfBatchDims, dims.k[0], dims.n[0]);
  auto c =
      matchOperandMap(indexingMaps[2], numOfBatchDims, dims.m[0], dims.n[0]);

  if (llvm::is_contained({a, b, c}, IndexMatchResult::Mismatch))
    return failure();

  if (c != IndexMatchResult::Match ||
      (a == IndexMatchResult::Transposed && b == IndexMatchResult::Transposed))
    return failure();

  std::optional<TypeFn> castTy = getCastTypeForMatmulLikeOp(genericOp);
  if (!castTy.has_value())
    return rewriter.notifyMatchFailure(
        genericOp, "contains invalid cast ops for the named matmul op");

  if (numOfBatchDims) {
    return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp, castTy);
  }
  return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp, castTy);
}
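
//===----------------------------------------------------------------------===//
// Specialize linalg generic ops to convolution ops.
//===----------------------------------------------------------------------===//

// Rewrites genericOp into the named convolution op ConvOpTy. The plain
// Conv1D/Conv2D/Conv3D ops carry no strides or dilations; every other variant
// takes both as i64 tensor attributes.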
template <typename ConvOpTy>
static FailureOr<LinalgOp>
specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp,
                   ArrayRef<int64_t> dilations, ArrayRef<int64_t> strides) {
  SmallVector<Value> inputs = genericOp.getDpsInputs();
  ValueRange outputs = genericOp.getDpsInits();
  SmallVector<Type> resultTypes = llvm::to_vector(genericOp.getResultTypes());

  LinalgOp namedOp;
  if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp> ||
                std::is_same_v<ConvOpTy, linalg::Conv2DOp> ||
                std::is_same_v<ConvOpTy, linalg::Conv3DOp>) {
    namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(genericOp, resultTypes,
                                                    inputs, outputs);
  } else {
    auto stridesAttr = rewriter.getI64TensorAttr(strides);
    auto dilationsAttr = rewriter.getI64TensorAttr(dilations);
    namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(
        genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
  }
  return namedOp;
}
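
// Dispatches over the supported convolution flavors: the first variant whose
// structural match succeeds supplies the dilations and strides for the
// rewrite.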
static FailureOr<LinalgOp>
specializeLinalgConvolutions(RewriterBase &rewriter, GenericOp genericOp) {
#define CONV_OP_SPECIALIZER(ConvOpTy)                                          \
  if (std::optional<DilationsAndStrides> convParams =                          \
          matchConvolutionOpOfType<ConvOpTy>(genericOp))                       \
    return specializeToConvOp<ConvOpTy>(                                       \
        rewriter, genericOp, convParams->dilations, convParams->strides);

  // One CONV_OP_SPECIALIZER(...) invocation per supported convolution op
  // (elided here).

  return failure();
#undef CONV_OP_SPECIALIZER
}
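
// Entry point for specializing a linalg.generic: cheap structural matches
// (copy, fill, broadcast, transpose, single elementwise ops) are tried first,
// then contraction and convolution matching.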
FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
                                                      GenericOp genericOp) {
  // Copy.
  if (isaCopyOpInterface(genericOp))
    return rewriter.replaceOpWithNewOp<CopyOp>(
        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);

  // Fill.
  if (std::optional<Value> fillValue = isaFillOpInterface(genericOp))
    return rewriter.replaceOpWithNewOp<FillOp>(
        genericOp, *fillValue, genericOp.getDpsInits()[0]);

  // Broadcast.
  std::optional<SmallVector<int64_t>> equivalentToBroadcast =
      isaBroadcastOpInterface(genericOp);
  if (equivalentToBroadcast) {
    auto dims = *equivalentToBroadcast;
    return rewriter.replaceOpWithNewOp<BroadcastOp>(
        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
        dims);
  }

  // Transpose.
  std::optional<SmallVector<int64_t>> equivalentToTranspose =
      isaTransposeOpInterface(genericOp);
  if (equivalentToTranspose) {
    auto permutation = *equivalentToTranspose;
    return rewriter.replaceOpWithNewOp<TransposeOp>(
        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
        permutation);
  }

  // Elementwise unary.
  if (isaElemwiseSingleUnaryOpInterface(genericOp)) {
    Operation *op = &genericOp.getBody()->front();
    if (isa<math::ExpOp>(op)) {
      return REPLACE_UNARY_OP(ExpOp);
    }
  }

  // Elementwise binary.
  if (isaElemwiseSingleBinaryOpInterface(genericOp)) {
    bool swap = areBinOpsSwapped(genericOp);
    Operation *op = &genericOp.getBody()->front();
    if (isa<arith::AddFOp>(op)) {
      return REPLACE_BINARY_OP(AddOp, swap);
    }
    if (isa<arith::SubFOp>(op)) {
      return REPLACE_BINARY_OP(SubOp, swap);
    }
    if (isa<arith::MulFOp>(op)) {
      return REPLACE_BINARY_OP(MulOp, swap);
    }
    if (isa<arith::DivFOp>(op)) {
      return REPLACE_BINARY_OP(DivOp, swap);
    }
  }

  // Contraction, e.g. matmul or batch_matmul.
  if (isaContractionOpInterface(genericOp)) {
    return specializeLinalgContractions(rewriter, genericOp);
  }

  // Convolution.
  if (isaConvolutionOpInterface(genericOp)) {
    return specializeLinalgConvolutions(rewriter, genericOp);
  }
  return failure();
}
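
// Pass boilerplate: runs the specialization patterns greedily over the
// payload.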
namespace {
struct LinalgSpecializeGenericOpsPass
    : public impl::LinalgSpecializeGenericOpsPassBase<
          LinalgSpecializeGenericOpsPass> {
  using impl::LinalgSpecializeGenericOpsPassBase<
      LinalgSpecializeGenericOpsPass>::LinalgSpecializeGenericOpsPassBase;
  void runOnOperation() override;
};
} // namespace

void LinalgSpecializeGenericOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  populateLinalgGenericOpsSpecializationPatterns(patterns);
  populateDecomposeProjectedPermutationPatterns(patterns);

  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
    signalPassFailure();
}