25#define GEN_PASS_DEF_LINALGSPECIALIZEGENERICOPSPASS
26#include "mlir/Dialect/Linalg/Passes.h.inc"
29#define DEBUG_TYPE "linalg-specialization"
31#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP) \
32 (rewriter.replaceOpWithNewOp<NEWOP>( \
34 ValueRange{genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 1 : 0], \
35 genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 0 : 1]}, \
36 ValueRange{genericOp.getDpsInits()[0]}))
57 Block *body = genericOp.getBody();
64 "binary op uses just one block arg");
90static FailureOr<LinalgOp>
92 bool emitCategoryOp) {
93 bool hasNonIdentityMaps =
94 !llvm::all_of(genericOp.getIndexingMapsArray(),
95 [](
AffineMap map) { return map.isIdentity(); });
98 if (hasNonIdentityMaps && !emitCategoryOp)
101 "non-identity indexing maps prevent specialization to named op");
106 auto replaceUnaryOp = [&](
auto namedOp, ElementwiseKind kind) -> LinalgOp {
109 newOp =
decltype(namedOp)::create(
110 rewriter, genericOp.getLoc(), genericOp.getDpsInputs(),
113 newOp = ElementwiseOp::create(
114 rewriter, genericOp.getLoc(), genericOp.getDpsInputs(),
115 genericOp.getDpsInits(),
116 ElementwiseKindAttr::get(rewriter.
getContext(), kind),
117 genericOp.getIndexingMaps());
124 Operation *op = &genericOp.getBody()->front();
126 if (isa<math::ExpOp>(op))
127 return replaceUnaryOp(ExpOp{}, ElementwiseKind::exp);
128 if (isa<math::LogOp>(op))
129 return replaceUnaryOp(LogOp{}, ElementwiseKind::log);
130 if (isa<math::AbsFOp>(op))
131 return replaceUnaryOp(AbsOp{}, ElementwiseKind::abs);
132 if (isa<math::CeilOp>(op))
133 return replaceUnaryOp(CeilOp{}, ElementwiseKind::ceil);
134 if (isa<math::FloorOp>(op))
135 return replaceUnaryOp(FloorOp{}, ElementwiseKind::floor);
136 if (isa<arith::NegFOp>(op))
137 return replaceUnaryOp(NegFOp{}, ElementwiseKind::negf);
138 if (
auto divOp = dyn_cast<arith::DivFOp>(op)) {
139 if (
auto constOp = dyn_cast_if_present<arith::ConstantOp>(
140 divOp.getLhs().getDefiningOp()))
141 if (cast<FloatAttr>(constOp.getValue()).getValue().isExactlyValue(1.0))
142 return replaceUnaryOp(ReciprocalOp{}, ElementwiseKind::reciprocal);
144 if (isa<math::RoundOp>(op))
145 return replaceUnaryOp(RoundOp{}, ElementwiseKind::round);
146 if (isa<math::SqrtOp>(op))
147 return replaceUnaryOp(SqrtOp{}, ElementwiseKind::sqrt);
148 if (isa<math::RsqrtOp>(op))
149 return replaceUnaryOp(RsqrtOp{}, ElementwiseKind::rsqrt);
150 if (
auto mulOp = dyn_cast<arith::MulFOp>(op);
151 mulOp && mulOp.getLhs() == mulOp.getRhs())
152 return replaceUnaryOp(SquareOp{}, ElementwiseKind::square);
153 if (isa<math::TanhOp>(op))
154 return replaceUnaryOp(TanhOp{}, ElementwiseKind::tanh);
155 if (isa<math::ErfOp>(op))
156 return replaceUnaryOp(ErfOp{}, ElementwiseKind::erf);
187enum class IndexMatchResult {
201static IndexMatchResult matchOperandMap(
AffineMap map,
unsigned rowDimIdx,
202 unsigned expectedPosOfRowDim,
203 unsigned expectedPosOfColDim) {
205 auto exprOfRowDim = map.
getResults()[rowDimIdx];
206 auto exprOfColDim = map.
getResults()[rowDimIdx + 1];
211 return IndexMatchResult::Mismatch;
213 auto posRowDim = cast<AffineDimExpr>(exprOfRowDim).getPosition();
214 auto posColDim = cast<AffineDimExpr>(exprOfColDim).getPosition();
216 if (expectedPosOfRowDim == posRowDim && expectedPosOfColDim == posColDim)
217 return IndexMatchResult::Match;
219 if (expectedPosOfRowDim == posColDim && expectedPosOfColDim == posRowDim)
220 return IndexMatchResult::Transposed;
222 return IndexMatchResult::Mismatch;
231template <
typename NamedOpTy>
232static LinalgOp replaceWithMatmulVariant(
RewriterBase &rewriter, GenericOp op,
233 std::optional<TypeFn> castTy,
238 if (castTy.has_value() && *castTy == TypeFn::cast_unsigned) {
240 "cast", TypeFnAttr::get(rewriter.
getContext(), *castTy));
241 attributes.push_back(castAttr);
248 return AffineMapAttr::get(map);
251 "indexing_maps", rewriter.
getArrayAttr(indexingMapsAttrVal));
252 attributes.push_back(indexingMapsAttr);
255 op,
ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
264static std::optional<TypeFn> getCastTypeForMatmulLikeOp(GenericOp genericOp) {
265 bool foundCastForMatmulOutput =
false;
267 genericOp.getBody()->walk([&](CastOpInterface castOp) {
276 if (!llvm::any_of(forwardSlice, [](
Operation *op) {
279 return isa<arith::MulIOp, arith::MulFOp, complex::MulOp>(op);
281 foundCastForMatmulOutput =
true;
286 if (isa<arith::ExtUIOp, arith::UIToFPOp, arith::FPToUIOp>(castOp))
287 castTyFns.push_back(TypeFn::cast_unsigned);
288 else if (isa<arith::ExtSIOp, arith::SIToFPOp, arith::FPToSIOp>(castOp))
289 castTyFns.push_back(TypeFn::cast_signed);
294 if (foundCastForMatmulOutput)
297 if (!castTyFns.empty()) {
301 if (!llvm::all_equal(castTyFns))
303 return castTyFns.front();
307 return TypeFn::cast_signed;
310static FailureOr<LinalgOp> specializeLinalgMmt4D(
RewriterBase &rewriter,
312 std::optional<TypeFn> castTy,
315 auto indexingMaps = genericOp.getIndexingMapsArray();
316 if (llvm::any_of(indexingMaps, [](
AffineMap m) {
321 auto aOuter = matchOperandMap(indexingMaps[0], 0, dims.
m[0], dims.
k[0]);
322 auto aInner = matchOperandMap(indexingMaps[0], 2, dims.
m[1], dims.
k[1]);
324 auto bOuter = matchOperandMap(indexingMaps[1], 0, dims.
k[0], dims.
n[0]);
325 auto bInner = matchOperandMap(indexingMaps[1], 2, dims.
k[1], dims.
n[1]);
327 auto cOuter = matchOperandMap(indexingMaps[2], 0, dims.
m[0], dims.
n[0]);
328 auto cInner = matchOperandMap(indexingMaps[2], 2, dims.
m[1], dims.
n[1]);
330 if (llvm::is_contained({aOuter, bOuter, cOuter}, IndexMatchResult::Mismatch))
332 if (llvm::is_contained({aInner, bInner, cInner}, IndexMatchResult::Mismatch))
338 return replaceWithMatmulVariant<Mmt4DOp>(rewriter, genericOp, castTy,
343static FailureOr<LinalgOp> specializeLinalgContractions(
RewriterBase &rewriter,
345 bool emitCategoryOp) {
346 if (genericOp.getNumDpsInputs() != 2 || genericOp.getNumDpsInits() != 1)
350 auto mapRange = genericOp.getIndexingMapsArray();
351 if (llvm::any_of(mapRange,
360 return (isa<arith::MulFOp>(first) && isa<arith::AddFOp>(second)) ||
361 (isa<arith::MulIOp>(first) && isa<arith::AddIOp>(second)) ||
362 (isa<complex::MulOp>(first) && isa<complex::AddOp>(second));
368 std::optional<TypeFn> castTy = getCastTypeForMatmulLikeOp(genericOp);
371 genericOp,
"contains invalid cast ops for the named matmul op");
375 return replaceWithMatmulVariant<ContractOp>(
376 rewriter, genericOp, castTy, genericOp.getIndexingMapsArray());
403 if (dims.
m.size() == 2 && dims.
n.size() == 2 && dims.
k.size() == 2)
404 return specializeLinalgMmt4D(rewriter, genericOp, castTy, dims);
405 if (dims.
m.size() != 1 || dims.
n.size() != 1 || dims.
k.size() != 1)
409 auto indexingMaps = genericOp.getIndexingMapsArray();
410 if (llvm::any_of(indexingMaps, [&dims](
AffineMap m) {
412 dims.
batch.size() + 2 ;
416 auto numOfBatchDims = dims.
batch.size();
417 if (indexingMaps[0].getNumDims() != numOfBatchDims + 3)
420 if (numOfBatchDims) {
424 if (llvm::any_of(indexingMaps, [numOfBatchDims](
AffineMap m) {
425 for (
unsigned i = 0; i < numOfBatchDims; ++i) {
428 cast<AffineDimExpr>(expr).getPosition() != i)
437 matchOperandMap(indexingMaps[0], numOfBatchDims, dims.
m[0], dims.
k[0]);
439 matchOperandMap(indexingMaps[1], numOfBatchDims, dims.
k[0], dims.
n[0]);
441 matchOperandMap(indexingMaps[2], numOfBatchDims, dims.
m[0], dims.
n[0]);
443 if (llvm::is_contained({a,
b, c}, IndexMatchResult::Mismatch))
447 auto *ctx = genericOp.getContext();
448 unsigned numLoopDims = numOfBatchDims + 3;
449 unsigned mIdx = numOfBatchDims;
450 unsigned nIdx = mIdx + 1;
451 unsigned kIdx = mIdx + 2;
454 auto makeMap = [&](IndexMatchResult match,
unsigned rowIdx,
unsigned colIdx) {
456 for (
unsigned i = 0; i < numOfBatchDims; ++i)
457 tensorDims.push_back(i);
458 if (match == IndexMatchResult::Transposed)
459 llvm::append_values(tensorDims, colIdx, rowIdx);
461 llvm::append_values(tensorDims, rowIdx, colIdx);
465 auto mapA = makeMap(a, mIdx, kIdx);
466 auto mapB = makeMap(
b, kIdx, nIdx);
467 auto mapC = makeMap(c, mIdx, nIdx);
472 if (numOfBatchDims) {
473 return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp, castTy,
476 return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp, castTy,
482template <
typename ConvOpTy>
483static FailureOr<LinalgOp>
484specializeToConvOp(
RewriterBase &rewriter, GenericOp genericOp,
493 if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp> ||
494 std::is_same_v<ConvOpTy, linalg::Conv2DOp> ||
495 std::is_same_v<ConvOpTy, linalg::Conv3DOp>) {
502 genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
508static FailureOr<LinalgOp> specializeLinalgConvolutions(
RewriterBase &rewriter,
509 GenericOp genericOp) {
510#define CONV_OP_SPECIALIZER(ConvOpTy) \
511 if (std::optional<DilationsAndStrides> convParams = \
512 matchConvolutionOpOfType<ConvOpTy>(genericOp)) \
513 return specializeToConvOp<ConvOpTy>( \
514 rewriter, genericOp, convParams->dilations, convParams->strides); \
571#undef CONV_OP_SPECIALIZER
591 return specializeLinalgContractions(rewriter, genericOp,
600 genericOp,
"no matching category op specialization");
605 genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
613 genericOp, *fillValue, genericOp.getDpsInits()[0]);
618 std::optional<SmallVector<int64_t>> equivalentToBroadcast =
620 if (equivalentToBroadcast) {
621 auto dims = *equivalentToBroadcast;
623 genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
629 std::optional<SmallVector<int64_t>> equivalentToTranspose =
631 if (equivalentToTranspose) {
632 auto permutation = *equivalentToTranspose;
634 genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
642 Operation *op = &genericOp.getBody()->front();
643 if (isa<arith::AddFOp>(op)) {
647 if (isa<arith::SubFOp>(op)) {
651 if (isa<arith::MulFOp>(op)) {
655 if (isa<arith::DivFOp>(op)) {
663 return specializeLinalgConvolutions(rewriter, genericOp);
666 "no matching named op specialization");
670struct LinalgSpecializeGenericOpsPass
671 :
public impl::LinalgSpecializeGenericOpsPassBase<
672 LinalgSpecializeGenericOpsPass> {
674 using impl::LinalgSpecializeGenericOpsPassBase<
675 LinalgSpecializeGenericOpsPass>::LinalgSpecializeGenericOpsPassBase;
676 void runOnOperation()
override;
680void LinalgSpecializeGenericOpsPass::runOnOperation() {
static llvm::ManagedStatic< PassManagerOptions > options
#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP)
static FailureOr< LinalgOp > specializeLinalgUnaryElementwise(RewriterBase &rewriter, GenericOp genericOp, bool emitCategoryOp)
#define CONV_OP_SPECIALIZER(ConvOpTy)
static bool areBinOpsSwapped(GenericOp genericOp)
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e. a projection) of a symbol-less permutation map.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
static AffineMap getMultiDimMapWithTargets(unsigned numDims, ArrayRef< unsigned > targets, MLIRContext *context)
Returns an affine map with numDims input dimensions and results specified by targets.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
DenseIntElementsAttr getI64TensorAttr(ArrayRef< int64_t > values)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
NamedAttribute getNamedAttr(StringRef name, Attribute val)
IRValueT get() const
Return the current value being used by this operand.
Operation is the basic unit of execution within MLIR.
OpOperand & getOpOperand(unsigned idx)
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
This class provides an abstraction over the different types of ranges over Values.
static WalkResult advance()
static WalkResult interrupt()
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
std::optional< SmallVector< int64_t > > isaTransposeOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.transpose.
bool isaCopyOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a linalg.copyOp.
void populateDecomposeProjectedPermutationPatterns(RewritePatternSet &patterns)
Add patterns to make explicit broadcasts and transforms in the input operands of a genericOp.
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp, const GenericOpSpecializationOptions &options={})
Replace the given GenericOp with a namedOp or categoryOp.
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaBroadcastOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a broadcast operation.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp, bool allowNonIdentityMaps=false)
Checks whether a given genericOp is semantically equivalent to a single linalg elementwise unary op,...
std::optional< Value > isaFillOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.fill.
void populateLinalgGenericOpsSpecializationPatterns(RewritePatternSet &patterns, const GenericOpSpecializationOptions &options={})
Populates patterns with patterns to convert linalg.generic ops to named or category ops where possibl...
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a single linalg elementwise binary op e....
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
@ DimId
Dimensional identifier.
llvm::SetVector< T, Vector, Set, N > SetVector
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e. all the transitive uses of op), without including that operation.
Positions of a Linalg op loops that correspond to different kinds of a contraction dimension.
SmallVector< unsigned, 2 > batch
SmallVector< unsigned, 2 > m
SmallVector< unsigned, 2 > n
SmallVector< unsigned, 2 > k