25#define GEN_PASS_DEF_LINALGSPECIALIZEGENERICOPSPASS
26#include "mlir/Dialect/Linalg/Passes.h.inc"
29#define DEBUG_TYPE "linalg-specialization"
31#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP) \
32 (rewriter.replaceOpWithNewOp<NEWOP>( \
34 ValueRange{genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 1 : 0], \
35 genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 0 : 1]}, \
36 ValueRange{genericOp.getDpsInits()[0]}))
38#define REPLACE_UNARY_OP(NEWOP) \
39 (rewriter.replaceOpWithNewOp<NEWOP>(genericOp, \
40 ValueRange{genericOp.getDpsInputs()[0]}, \
41 ValueRange{genericOp.getDpsInits()[0]}))
58 Block *body = genericOp.getBody();
65 "binary op uses just one block arg");
96enum class IndexMatchResult {
110static IndexMatchResult matchOperandMap(
AffineMap map,
unsigned rowDimIdx,
111 unsigned expectedPosOfRowDim,
112 unsigned expectedPosOfColDim) {
114 auto exprOfRowDim = map.
getResults()[rowDimIdx];
115 auto exprOfColDim = map.
getResults()[rowDimIdx + 1];
120 return IndexMatchResult::Mismatch;
122 auto posRowDim = cast<AffineDimExpr>(exprOfRowDim).getPosition();
123 auto posColDim = cast<AffineDimExpr>(exprOfColDim).getPosition();
125 if (expectedPosOfRowDim == posRowDim && expectedPosOfColDim == posColDim)
126 return IndexMatchResult::Match;
128 if (expectedPosOfRowDim == posColDim && expectedPosOfColDim == posRowDim)
129 return IndexMatchResult::Transposed;
131 return IndexMatchResult::Mismatch;
140template <
typename NamedOpTy>
141static LinalgOp replaceWithMatmulVariant(
RewriterBase &rewriter, GenericOp op,
142 std::optional<TypeFn> castTy,
147 if (castTy.has_value() && *castTy == TypeFn::cast_unsigned) {
149 "cast", TypeFnAttr::get(rewriter.
getContext(), *castTy));
150 attributes.push_back(castAttr);
157 return AffineMapAttr::get(map);
160 "indexing_maps", rewriter.
getArrayAttr(indexingMapsAttrVal));
161 attributes.push_back(indexingMapsAttr);
164 op,
ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
173static std::optional<TypeFn> getCastTypeForMatmulLikeOp(GenericOp genericOp) {
174 bool foundCastForMatmulOutput =
false;
176 genericOp.getBody()->walk([&](CastOpInterface castOp) {
185 if (!llvm::any_of(forwardSlice, [](
Operation *op) {
188 return isa<arith::MulIOp, arith::MulFOp, complex::MulOp>(op);
190 foundCastForMatmulOutput =
true;
195 if (isa<arith::ExtUIOp, arith::UIToFPOp, arith::FPToUIOp>(castOp))
196 castTyFns.push_back(TypeFn::cast_unsigned);
197 else if (isa<arith::ExtSIOp, arith::SIToFPOp, arith::FPToSIOp>(castOp))
198 castTyFns.push_back(TypeFn::cast_signed);
203 if (foundCastForMatmulOutput)
206 if (!castTyFns.empty()) {
210 if (!llvm::all_equal(castTyFns))
212 return castTyFns.front();
216 return TypeFn::cast_signed;
220static FailureOr<LinalgOp> specializeLinalgContractions(
RewriterBase &rewriter,
222 bool emitCategoryOp) {
223 if (genericOp.getNumDpsInputs() != 2 || genericOp.getNumDpsInits() != 1)
227 auto mapRange = genericOp.getIndexingMapsArray();
228 if (llvm::any_of(mapRange,
237 return (isa<arith::MulFOp>(first) && isa<arith::AddFOp>(second)) ||
238 (isa<arith::MulIOp>(first) && isa<arith::AddIOp>(second)) ||
239 (isa<complex::MulOp>(first) && isa<complex::AddOp>(second));
245 std::optional<TypeFn> castTy = getCastTypeForMatmulLikeOp(genericOp);
248 genericOp,
"contains invalid cast ops for the named matmul op");
252 return replaceWithMatmulVariant<ContractOp>(
253 rewriter, genericOp, castTy, genericOp.getIndexingMapsArray());
280 if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)
284 auto indexingMaps = genericOp.getIndexingMapsArray();
285 if (llvm::any_of(indexingMaps, [&dims](
AffineMap m) {
287 dims.batch.size() + 2 ;
291 auto numOfBatchDims = dims.batch.size();
292 if (indexingMaps[0].getNumDims() != numOfBatchDims + 3)
295 if (numOfBatchDims) {
299 if (llvm::any_of(indexingMaps, [numOfBatchDims](
AffineMap m) {
300 for (
unsigned i = 0; i < numOfBatchDims; ++i) {
303 cast<AffineDimExpr>(expr).getPosition() != i)
312 matchOperandMap(indexingMaps[0], numOfBatchDims, dims.m[0], dims.k[0]);
314 matchOperandMap(indexingMaps[1], numOfBatchDims, dims.k[0], dims.n[0]);
316 matchOperandMap(indexingMaps[2], numOfBatchDims, dims.m[0], dims.n[0]);
318 if (llvm::is_contained({a,
b, c}, IndexMatchResult::Mismatch))
322 auto *ctx = genericOp.getContext();
323 unsigned numLoopDims = numOfBatchDims + 3;
324 unsigned mIdx = numOfBatchDims;
325 unsigned nIdx = mIdx + 1;
326 unsigned kIdx = mIdx + 2;
329 auto makeMap = [&](IndexMatchResult match,
unsigned rowIdx,
unsigned colIdx) {
331 for (
unsigned i = 0; i < numOfBatchDims; ++i)
332 tensorDims.push_back(i);
333 if (match == IndexMatchResult::Transposed)
334 llvm::append_values(tensorDims, colIdx, rowIdx);
336 llvm::append_values(tensorDims, rowIdx, colIdx);
340 auto mapA = makeMap(a, mIdx, kIdx);
341 auto mapB = makeMap(
b, kIdx, nIdx);
342 auto mapC = makeMap(c, mIdx, nIdx);
347 if (numOfBatchDims) {
348 return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp, castTy,
351 return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp, castTy,
357template <
typename ConvOpTy>
358static FailureOr<LinalgOp>
359specializeToConvOp(
RewriterBase &rewriter, GenericOp genericOp,
368 if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp> ||
369 std::is_same_v<ConvOpTy, linalg::Conv2DOp> ||
370 std::is_same_v<ConvOpTy, linalg::Conv3DOp>) {
377 genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
383static FailureOr<LinalgOp> specializeLinalgConvolutions(
RewriterBase &rewriter,
384 GenericOp genericOp) {
385#define CONV_OP_SPECIALIZER(ConvOpTy) \
386 if (std::optional<DilationsAndStrides> convParams = \
387 matchConvolutionOpOfType<ConvOpTy>(genericOp)) \
388 return specializeToConvOp<ConvOpTy>( \
389 rewriter, genericOp, convParams->dilations, convParams->strides); \
446#undef CONV_OP_SPECIALIZER
460 return specializeLinalgContractions(rewriter, genericOp,
469 genericOp,
"no matching category op specialization");
474 genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
482 genericOp, *fillValue, genericOp.getDpsInits()[0]);
487 std::optional<SmallVector<int64_t>> equivalentToBroadcast =
489 if (equivalentToBroadcast) {
490 auto dims = *equivalentToBroadcast;
492 genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
498 std::optional<SmallVector<int64_t>> equivalentToTranspose =
500 if (equivalentToTranspose) {
501 auto permutation = *equivalentToTranspose;
503 genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
510 Operation *op = &genericOp.getBody()->front();
511 if (isa<math::ExpOp>(op)) {
520 Operation *op = &genericOp.getBody()->front();
521 if (isa<arith::AddFOp>(op)) {
525 if (isa<arith::SubFOp>(op)) {
529 if (isa<arith::MulFOp>(op)) {
533 if (isa<arith::DivFOp>(op)) {
541 return specializeLinalgConvolutions(rewriter, genericOp);
544 "no matching named op specialization");
548struct LinalgSpecializeGenericOpsPass
550 LinalgSpecializeGenericOpsPass> {
553 LinalgSpecializeGenericOpsPass>::LinalgSpecializeGenericOpsPassBase;
554 void runOnOperation()
override;
558void LinalgSpecializeGenericOpsPass::runOnOperation() {
static llvm::ManagedStatic< PassManagerOptions > options
#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP)
#define CONV_OP_SPECIALIZER(ConvOpTy)
static bool areBinOpsSwapped(GenericOp genericOp)
#define REPLACE_UNARY_OP(NEWOP)
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e. a projection) of a symbol-less permutation map.
ArrayRef< AffineExpr > getResults() const
static AffineMap getMultiDimMapWithTargets(unsigned numDims, ArrayRef< unsigned > targets, MLIRContext *context)
Returns an affine map with numDims input dimensions and results specified by targets.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
DenseIntElementsAttr getI64TensorAttr(ArrayRef< int64_t > values)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
NamedAttribute getNamedAttr(StringRef name, Attribute val)
IRValueT get() const
Return the current value being used by this operand.
Operation is the basic unit of execution within MLIR.
OpOperand & getOpOperand(unsigned idx)
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and to provide a callback that populates a diagnostic with the reason for the failure.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification.
This class provides an abstraction over the various different ranges of value types.
This class provides an abstraction over the different types of ranges over Values.
static WalkResult advance()
static WalkResult interrupt()
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
std::optional< SmallVector< int64_t > > isaTransposeOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.transpose.
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp)
Checks whether a given genericOp is semantically equivalent to a single linalg elementwise unary op.
bool isaCopyOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a linalg.copyOp.
void populateDecomposeProjectedPermutationPatterns(RewritePatternSet &patterns)
Add patterns to make explicit broadcasts and transforms in the input operands of a genericOp.
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp, const GenericOpSpecializationOptions &options={})
Replace the given GenericOp with a namedOp or categoryOp.
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaBroadcastOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a broadcast operation.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcomputation.
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
std::optional< Value > isaFillOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.fill.
void populateLinalgGenericOpsSpecializationPatterns(RewritePatternSet &patterns, const GenericOpSpecializationOptions &options={})
Populates patterns with patterns to convert linalg.generic ops to named or category ops where possible.
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a single linalg elementwise binary op e....
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest-benefit patterns in a greedy worklist-driven manner until a fixed point is reached.
@ DimId
Dimensional identifier.
llvm::SetVector< T, Vector, Set, N > SetVector
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e. all transitive uses of op).