MLIR  17.0.0git
Macros | Functions
LinalgTransformOps.cpp File Reference
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/TransformOps/Syntax.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/TypeID.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <type_traits>
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"

Go to the source code of this file.

Macros

#define DEBUG_TYPE   "linalg-transforms"
 
#define DBGS()   (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
 
#define DBGSNL()   (llvm::dbgs() << "\n")
 
#define DOWNSCALE(trans)
 
#define DOWNSCALE_CALL(a, b)   DownscaleSizeOneWindowed2DConvolution<a, b>
 
#define DOWNSCALE_NORMAL(a, b)   DOWNSCALE(DOWNSCALE_CALL(a, b))
 
#define GET_OP_CLASSES
 

Functions

template<typename PatternTy , typename... Args>
static FailureOr< LinalgOp > tryApply (Operation *operation, Args &&...args)
 Attempts to apply the pattern specified as template argument to the given operation. More...
 
static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations (transform::TransformState &state, TransformOpInterface transformOp, SmallVector< OpFoldResult > &result, ArrayRef< OpFoldResult > ofrs)
 Assuming that ofr is an index attr or a transform dialect handle mapped to exactly one op with one index result, return that value. More...
 
static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations (transform::TransformState &state, TransformOpInterface transformOp, SmallVector< OpFoldResult > &result, Value packedHandle)
 
template<typename Range >
static LogicalResult applyTilingToAll (RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, unsigned numLoops, transform::TransformResults &transformResults, function_ref< FailureOr< scf::SCFTileAndFuseResult >(TilingInterface)> applyFn)
 Apply a tiling transformation to all payload ops and store both the tiled operation as well as the created tile loops. More...
 
static Operation * replaceForAllWithNewSignature (RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp, TilingResult &tileAndFuseResult, int64_t resultNumber, SmallVector< OpFoldResult > &offsets, SmallVector< OpFoldResult > &sizes)
 Add new operands to the forall op for users of the producerOp that are dominated by the containing scf.forall op. More...
 
static std::tuple< SmallVector< Operation * >, Operation * > tileAndFuseFirstExtractUse (RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
 Find the first "extract" user of producerOp and tile it right before its use. More...
 
static SmallVector< Operation * > tileAndFuseFirstExtractUseThroughContainingOpBlockArgument (RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
 First, find the first "scf::ForallOp" user of producerOp and ensure it is exactly the containingOp, otherwise bail. More...
 
static Operation * cloneAndFuseFirstUse (RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
 
static void forciblyReplaceReferencedPayloadOperation (TransformState &state, Operation *payload, Operation *replacement)
 Replaces payload with replacement in all handles stored in the state. More...
 
static void printMultitileSizesTypes (OpAsmPrinter &printer, Operation *op, Type targetType, Type lowSizeType, Type, Type)
 
static ParseResult parseMultitileSizesTypes (OpAsmParser &parser, Type &targetType, Type &lowSizeType, Type &highSizeType, Type &splitPointType)
 
static FailureOr< PackResult > packMatmulGreedily (RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
 Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel dimensions and k is a proper reduction dimension. More...
 
template<typename RelayoutOpTy >
bool isValidPackingPermutation (RelayoutOpTy op, ArrayRef< int64_t > permutation, OuterOrInnerPerm outerOrInnerPerm=OuterOrInnerPerm::Outer)
 Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Outer) or inner_dims_pos (OuterOrInnerPerm::Inner) of the tensor.pack or tensor.unpack op. More...
 
ParseResult parseOptionalInterchange (OpAsmParser &parser, OperationState &result)
 
void printOptionalInterchange (OpAsmPrinter &p, ArrayRef< int64_t > interchangeVals)
 
template<typename OpTy >
DiagnosedSilenceableFailure doit (RewriterBase &rewriter, OpTy target, transform::ApplyToEachResultList &results, transform::TransformState &state)
 

Macro Definition Documentation

◆ DBGS

#define DBGS ( )    (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

Definition at line 50 of file LinalgTransformOps.cpp.

◆ DBGSNL

#define DBGSNL ( )    (llvm::dbgs() << "\n")

Definition at line 51 of file LinalgTransformOps.cpp.

◆ DEBUG_TYPE

#define DEBUG_TYPE   "linalg-transforms"

Definition at line 49 of file LinalgTransformOps.cpp.

◆ DOWNSCALE

#define DOWNSCALE (   trans)
Value:
{ \
FailureOr<LinalgOp> res = tryApply<trans>(target); \
if (succeeded(res)) { \
results.push_back(*res); \
} \
}
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56

◆ DOWNSCALE_CALL

#define DOWNSCALE_CALL(a, b)   DownscaleSizeOneWindowed2DConvolution<a, b>

◆ DOWNSCALE_NORMAL

#define DOWNSCALE_NORMAL(a, b)   DOWNSCALE(DOWNSCALE_CALL(a, b))

◆ GET_OP_CLASSES

#define GET_OP_CLASSES

Definition at line 3278 of file LinalgTransformOps.cpp.

Function Documentation

◆ applyTilingToAll()

template<typename Range >
static LogicalResult applyTilingToAll ( RewriterBase & rewriter,
Operation * transformOp,
Range &&  payloadOps,
unsigned  numLoops,
transform::TransformResults & transformResults,
function_ref< FailureOr< scf::SCFTileAndFuseResult >(TilingInterface)>  applyFn 
)
static

Apply a tiling transformation to all payload ops and store both the tiled operation as well as the created tile loops.

Definition at line 212 of file LinalgTransformOps.cpp.

References mlir::Operation::emitError(), mlir::failed(), mlir::failure(), mlir::Operation::getOpResult(), mlir::RewriterBase::replaceOp(), mlir::transform::TransformResults::set(), mlir::OpBuilder::setInsertionPoint(), and mlir::success().

◆ cloneAndFuseFirstUse()

static Operation* cloneAndFuseFirstUse ( RewriterBase & rewriter,
Diagnostic & diag,
Operation * producerOp,
Operation * containingOp 
)
static

◆ doit()

template<typename OpTy >
DiagnosedSilenceableFailure doit ( RewriterBase & rewriter,
OpTy  target,
transform::ApplyToEachResultList & results,
transform::TransformState & state 
)

◆ forciblyReplaceReferencedPayloadOperation()

static void forciblyReplaceReferencedPayloadOperation ( TransformState & state,
Operation * payload,
Operation * replacement 
)
static

Replaces payload with replacement in all handles stored in the state.

MUST NOT be used except for the case immediately below its definition in LinalgTransformOps.cpp.

Definition at line 689 of file LinalgTransformOps.cpp.

◆ isValidPackingPermutation()

template<typename RelayoutOpTy >
bool isValidPackingPermutation ( RelayoutOpTy  op,
ArrayRef< int64_t >  permutation,
OuterOrInnerPerm  outerOrInnerPerm = OuterOrInnerPerm::Outer 
)

Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Outer) or inner_dims_pos (OuterOrInnerPerm::Inner) of the tensor.pack or tensor.unpack op.

This is the case when the rank of `permutation` matches the rank expected by `op` and `permutation` is itself a permutation vector. Return true if either `op` or `permutation` is empty, to allow a simpler polymorphic implementation.

Definition at line 1420 of file LinalgTransformOps.cpp.

◆ packMatmulGreedily()

static FailureOr<PackResult> packMatmulGreedily ( RewriterBase & rewriter,
LinalgOp  linalgOp,
ArrayRef< OpFoldResult >  mnkPackedSizes,
ArrayRef< int64_t >  mnkPaddedSizesNextMultipleOf,
ArrayRef< int64_t >  mnkOrder 
)
static

Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel dimensions and k is a proper reduction dimension.

Packing occurs by rewriting the op as a linalg.generic and calling linalg::pack by mnkPackedSizes. The order of the packed dimensions is customizable: the mnkOrder is a permutation of {0, 1, 2} that reorders {m, n, k} into one of the 6 possible forms. The outer dimensions of the operands are not permuted at this time; this is left for future work.

Definition at line 1217 of file LinalgTransformOps.cpp.

References mlir::bindDims(), mlir::bindSymbols(), mlir::AffineExpr::ceilDiv(), mlir::computePermutationVector(), DBGS, DBGSNL, mlir::failed(), mlir::linalg::generalizeNamedOp(), mlir::Builder::getContext(), mlir::Builder::getIndexAttr(), mlir::linalg::inferMatmulDims(), mlir::linalg::interchangeGenericOp(), mlir::isPermutationVector(), mlir::affine::makeComposedFoldedAffineApply(), mlir::RewriterBase::notifyMatchFailure(), mlir::linalg::pack(), mlir::Range::size, and mlir::succeeded().

◆ parseMultitileSizesTypes()

static ParseResult parseMultitileSizesTypes ( OpAsmParser & parser,
Type & targetType,
Type & lowSizeType,
Type & highSizeType,
Type & splitPointType 
)
static

◆ parseOptionalInterchange()

ParseResult parseOptionalInterchange ( OpAsmParser & parser,
OperationState & result 
)

◆ printMultitileSizesTypes()

static void printMultitileSizesTypes ( OpAsmPrinter & printer,
Operation * op,
Type  targetType,
Type  lowSizeType,
Type  ,
Type   
)
static

Definition at line 1005 of file LinalgTransformOps.cpp.

◆ printOptionalInterchange()

void printOptionalInterchange ( OpAsmPrinter & p,
ArrayRef< int64_t >  interchangeVals 
)

Definition at line 2505 of file LinalgTransformOps.cpp.

◆ replaceForAllWithNewSignature()

static Operation* replaceForAllWithNewSignature ( RewriterBase & rewriter,
Diagnostic & diag,
Operation * producerOp,
Operation * containingOp,
TilingResult & tileAndFuseResult,
int64_t  resultNumber,
SmallVector< OpFoldResult > &  offsets,
SmallVector< OpFoldResult > &  sizes 
)
static

◆ tileAndFuseFirstExtractUse()

static std::tuple<SmallVector<Operation *>, Operation *> tileAndFuseFirstExtractUse ( RewriterBase & rewriter,
Diagnostic & diag,
Operation * producerOp,
Operation * containingOp 
)
static

Find the first "extract" user of producerOp and tile it right before its use.

The tiled op is fused under the containingOp. Return this fused op on success, or nullptr if anything fails. If the tiled op has uses that are dominated by containingOp, return a new containingOp whose results are those of the original containingOp with the results of the fused op appended; return nullptr for the containingOp if there are no dominated uses.

Definition at line 437 of file LinalgTransformOps.cpp.

References DBGS, diag(), mlir::failed(), mlir::Operation::getLoc(), replaceForAllWithNewSignature(), mlir::RewriterBase::replaceOp(), mlir::OpBuilder::setInsertionPoint(), and mlir::succeeded().

◆ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument()

static SmallVector<Operation *> tileAndFuseFirstExtractUseThroughContainingOpBlockArgument ( RewriterBase & rewriter,
Diagnostic & diag,
Operation * producerOp,
Operation * containingOp 
)
static

First, find the first "scf::ForallOp" user of producerOp and ensure it is exactly the containingOp, otherwise bail.

Then, find the first "extract" user of the tied block argument and tile it right before its "extract" use. The tiled op is fused under the containingOp. Return this fused op on success or nullptr if anything fails.

Definition at line 513 of file LinalgTransformOps.cpp.

References mlir::OpBuilder::clone(), DBGS, diag(), mlir::RewriterBase::eraseOp(), mlir::failed(), mlir::IROperand< DerivedT, IRValueT >::get(), mlir::Operation::getLoc(), mlir::OpOperand::getOperandNumber(), mlir::tensor::getOrCreateDestinations(), mlir::Value::getUsers(), mlir::IRMapping::map(), mlir::RewriterBase::replaceOp(), mlir::OpBuilder::setInsertionPoint(), mlir::Operation::setOperand(), mlir::succeeded(), and mlir::RewriterBase::updateRootInPlace().

◆ tryApply()

template<typename PatternTy , typename... Args>
static FailureOr<LinalgOp> tryApply ( Operation * operation,
Args &&...  args 
)
static

Attempts to apply the pattern specified as template argument to the given operation.

The pattern is expected to have a returningMatchAndRewrite function that returns the "main" result or failure. Returns failure if the pattern failed to apply. Extra arguments are forwarded to the pattern constructor.

Definition at line 59 of file LinalgTransformOps.cpp.

◆ unpackSingleIndexResultPayloadOperations() [1/2]

static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations ( transform::TransformState & state,
TransformOpInterface  transformOp,
SmallVector< OpFoldResult > &  result,
ArrayRef< OpFoldResult >  ofrs 
)
static

Assuming that ofr is an index attr or a transform dialect handle mapped to exactly one op with one index result, return that value.

Definition at line 86 of file LinalgTransformOps.cpp.

References diag(), mlir::Value::getLoc(), and mlir::transform::TransformState::getPayloadOps().

◆ unpackSingleIndexResultPayloadOperations() [2/2]

static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations ( transform::TransformState & state,
TransformOpInterface  transformOp,
SmallVector< OpFoldResult > &  result,
Value  packedHandle 
)
static

Definition at line 126 of file LinalgTransformOps.cpp.