MLIR  20.0.0git
Macros | Functions
LinalgTransformOps.cpp File Reference
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h"
#include "mlir/Dialect/Linalg/TransformOps/Syntax.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Transform/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/TypeID.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <type_traits>
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"

Go to the source code of this file.

Macros

#define DEBUG_TYPE   "linalg-transforms"
 
#define DBGS()   (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
 
#define DBGSNL()   (llvm::dbgs() << "\n")
 
#define LDBG(X)   LLVM_DEBUG(DBGS() << (X) << "\n")
 
#define DOWNSCALE(trans)
 
#define DOWNSCALE_CALL(a, b)   DownscaleSizeOneWindowed2DConvolution<a, b>
 
#define DOWNSCALE_NORMAL(a, b)   DOWNSCALE(DOWNSCALE_CALL(a, b))
 
#define GET_OP_CLASSES
 

Functions

template<typename PatternTy , typename... Args>
static FailureOr< LinalgOp > tryApply (Operation *operation, Args &&...args)
 Attempts to apply the pattern specified as template argument to the given operation. More...
 
static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations (transform::TransformState &state, TransformOpInterface transformOp, SmallVector< OpFoldResult > &result, ArrayRef< OpFoldResult > ofrs)
 Assuming that ofr is an index attr or a param of index type or a transform dialect handle mapped to exactly one op with one index result, return that value. More...
 
static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations (transform::TransformState &state, TransformOpInterface transformOp, SmallVector< OpFoldResult > &result, Value packedHandle)
 
static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults (TransformState &state, TransformOpInterface &transformOp, ArrayRef< OpFoldResult > mixedResults, SmallVectorImpl< int64_t > &reified)
 When possible, converts each OpFoldResult in mixedResult to an integer if the value can be statically inferred. More...
 
template<typename Range >
static LogicalResult applyTilingToAll (RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, unsigned numLoops, transform::TransformResults &transformResults, function_ref< FailureOr< scf::SCFTileAndFuseResult >(TilingInterface)> applyFn)
 Apply a tiling transformation to all payload ops and store both the tiled operation as well as the created tile loops. More...
 
static Operation * replaceForAllWithNewSignature (RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp, TilingResult &tileAndFuseResult, int64_t resultNumber, SmallVector< OpFoldResult > &offsets, SmallVector< OpFoldResult > &sizes)
 Add new operands to the forall op for users of the producerOp that are dominated by the containing scf.forall op. More...
 
static std::tuple< SmallVector< Operation * >, Operation * > tileAndFuseFirstExtractUse (RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
 Find the first "extract" user of producerOp and tile it right before its use. More...
 
static SmallVector< Operation * > tileAndFuseFirstExtractUseThroughContainingOpBlockArgument (RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
 First, find the first "scf::ForallOp" user of producerOp and ensure it is exactly the containingOp, otherwise bail. More...
 
static Operation * cloneAndFuseFirstUse (RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
 
static void printMultitileSizesTypes (OpAsmPrinter &printer, Operation *op, Type targetType, Type lowSizeType, Type, Type)
 
static ParseResult parseMultitileSizesTypes (OpAsmParser &parser, Type &targetType, Type &lowSizeType, Type &highSizeType, Type &splitPointType)
 
template<typename RelayoutOpTy >
bool isValidPackingPermutation (RelayoutOpTy op, ArrayRef< int64_t > permutation, OuterOrInnerPerm outerOrInnerPerm=OuterOrInnerPerm::Outer)
 Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Outer) or inner_dims_pos (OuterOrInnerPerm::Inner) of the tensor.pack or tensor.unpack op. More...
 
static void printContinuousTileSizeTypes (OpAsmPrinter &printer, Operation *op, Type targetType, Type tile_sizes, Type)
 
static ParseResult parseContinuousTileSizeTypes (OpAsmParser &parser, Type &targetType, Type &tileSizesType, Type &chunkSizesType)
 
static SmallVector< OpFoldResult > normalizeUpperBounds (RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > ubs, ArrayRef< OpFoldResult > steps)
 Given lbs, ubs and steps of loops, return (for each loop), the normalized upper bound. More...
 
static SmallVector< Value > denormalizeIndVar (RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > steps)
 When a loop is normalized, the uses of the induction variable within the loop need to be replaced with original_lb + old_iv * original_step. More...
 
static scf::ForallOp normalizeForallLoopOp (RewriterBase &rewriter, scf::ForallOp loop)
 Given a scf.forall loop return a loop op with the loop bounds normalized. More...
 
template<typename OpTy >
DiagnosedSilenceableFailure doit (RewriterBase &rewriter, OpTy target, transform::ApplyToEachResultList &results, transform::TransformState &state)
 

Macro Definition Documentation

◆ DBGS

#define DBGS ( )    (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

Definition at line 55 of file LinalgTransformOps.cpp.

◆ DBGSNL

#define DBGSNL ( )    (llvm::dbgs() << "\n")

Definition at line 56 of file LinalgTransformOps.cpp.

◆ DEBUG_TYPE

#define DEBUG_TYPE   "linalg-transforms"

Definition at line 54 of file LinalgTransformOps.cpp.

◆ DOWNSCALE

#define DOWNSCALE (   trans)
Value:
{ \
FailureOr<LinalgOp> res = tryApply<trans>(target); \
if (succeeded(res)) { \
results.push_back(*res); \
return DiagnosedSilenceableFailure::success(); \
} \
}

◆ DOWNSCALE_CALL

#define DOWNSCALE_CALL (   a, b )    DownscaleSizeOneWindowed2DConvolution<a, b>

◆ DOWNSCALE_NORMAL

#define DOWNSCALE_NORMAL (   a, b )    DOWNSCALE(DOWNSCALE_CALL(a, b))

◆ GET_OP_CLASSES

#define GET_OP_CLASSES

Definition at line 3936 of file LinalgTransformOps.cpp.

◆ LDBG

#define LDBG (   X)    LLVM_DEBUG(DBGS() << (X) << "\n")

Definition at line 57 of file LinalgTransformOps.cpp.

Function Documentation

◆ applyTilingToAll()

template<typename Range >
static LogicalResult applyTilingToAll ( RewriterBase &  rewriter,
Operation *  transformOp,
Range &&  payloadOps,
unsigned  numLoops,
transform::TransformResults &  transformResults,
function_ref< FailureOr< scf::SCFTileAndFuseResult >(TilingInterface)>  applyFn 
)
static

Apply a tiling transformation to all payload ops and store both the tiled operation as well as the created tile loops.

Definition at line 517 of file LinalgTransformOps.cpp.

References mlir::Operation::emitError(), mlir::RewriterBase::eraseOp(), mlir::Operation::getOpResult(), mlir::RewriterBase::replaceAllUsesWith(), mlir::transform::TransformResults::set(), and mlir::OpBuilder::setInsertionPoint().

◆ cloneAndFuseFirstUse()

static Operation* cloneAndFuseFirstUse ( RewriterBase &  rewriter,
Diagnostic &  diag,
Operation *  producerOp,
Operation *  containingOp 
)
static

◆ denormalizeIndVar()

static SmallVector<Value> denormalizeIndVar ( RewriterBase &  rewriter,
Location  loc,
ValueRange  ivs,
ArrayRef< OpFoldResult >  lbs,
ArrayRef< OpFoldResult >  steps 
)
static

When a loop is normalized, the uses of the induction variable within the loop need to be replaced with original_lb + old_iv * original_step.

Definition at line 3207 of file LinalgTransformOps.cpp.

References mlir::bindDims(), mlir::bindSymbols(), mlir::Builder::getContext(), mlir::getValueOrCreateConstantIndexOp(), and mlir::affine::makeComposedFoldedAffineApply().

Referenced by normalizeForallLoopOp().

◆ doit()

template<typename OpTy >
DiagnosedSilenceableFailure doit ( RewriterBase &  rewriter,
OpTy  target,
transform::ApplyToEachResultList &  results,
transform::TransformState &  state 
)

◆ isValidPackingPermutation()

template<typename RelayoutOpTy >
bool isValidPackingPermutation ( RelayoutOpTy  op,
ArrayRef< int64_t >  permutation,
OuterOrInnerPerm  outerOrInnerPerm = OuterOrInnerPerm::Outer 
)

Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Outer) or inner_dims_pos (OuterOrInnerPerm::Inner) of the tensor.pack or tensor.unpack op.

This is the case when the `permutation` rank matches the rank expected by `op` and `permutation` is itself a permutation vector. Return true if either `op` or `permutation` is empty, to allow a simpler polymorphic implementation.

Definition at line 1622 of file LinalgTransformOps.cpp.

References mlir::isPermutationVector().

◆ normalizeForallLoopOp()

static scf::ForallOp normalizeForallLoopOp ( RewriterBase &  rewriter,
scf::ForallOp  loop 
)
static

Given a scf.forall loop return a loop op with the loop bounds normalized.

TODO: Replace this with a general utility to normalize scf.forall. At the time of writing, this wasn't done, since adding this to the scf dialect would disallow the use of affine.apply operations due to cyclic dependencies. To avoid churn in lit tests with the change this was added with, defer that to a follow-up.

Definition at line 3234 of file LinalgTransformOps.cpp.

References mlir::OpBuilder::create(), denormalizeIndVar(), mlir::Builder::getIndexAttr(), mlir::isConstantIntValue(), mlir::RewriterBase::mergeBlocks(), normalizeUpperBounds(), mlir::RewriterBase::replaceOp(), and mlir::OpBuilder::setInsertionPointToStart().

◆ normalizeUpperBounds()

static SmallVector<OpFoldResult> normalizeUpperBounds ( RewriterBase &  rewriter,
Location  loc,
ArrayRef< OpFoldResult >  lbs,
ArrayRef< OpFoldResult >  ubs,
ArrayRef< OpFoldResult >  steps 
)
static

Given lbs, ubs and steps of loops, return (for each loop), the normalized upper bound.

Definition at line 3190 of file LinalgTransformOps.cpp.

References mlir::bindSymbols(), mlir::Builder::getContext(), and mlir::affine::makeComposedFoldedAffineApply().

Referenced by normalizeForallLoopOp().

◆ parseContinuousTileSizeTypes()

static ParseResult parseContinuousTileSizeTypes ( OpAsmParser &  parser,
Type &  targetType,
Type &  tileSizesType,
Type &  chunkSizesType 
)
static

◆ parseMultitileSizesTypes()

static ParseResult parseMultitileSizesTypes ( OpAsmParser &  parser,
Type &  targetType,
Type &  lowSizeType,
Type &  highSizeType,
Type &  splitPointType 
)
static

◆ printContinuousTileSizeTypes()

static void printContinuousTileSizeTypes ( OpAsmPrinter &  printer,
Operation *  op,
Type  targetType,
Type  tile_sizes,
Type   
)
static

◆ printMultitileSizesTypes()

static void printMultitileSizesTypes ( OpAsmPrinter &  printer,
Operation *  op,
Type  targetType,
Type  lowSizeType,
Type  ,
Type   
)
static

◆ reifyMixedParamAndHandleResults()

static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults ( TransformState &  state,
TransformOpInterface &  transformOp,
ArrayRef< OpFoldResult >  mixedResults,
SmallVectorImpl< int64_t > &  reified 
)
static

When possible, converts each OpFoldResult in mixedResult to an integer if the value can be statically inferred.

If a result is a Value, then it must be either a ParamType or a handle to a constant-like op.

Definition at line 179 of file LinalgTransformOps.cpp.

◆ replaceForAllWithNewSignature()

static Operation* replaceForAllWithNewSignature ( RewriterBase &  rewriter,
Diagnostic &  diag,
Operation *  producerOp,
Operation *  containingOp,
TilingResult &  tileAndFuseResult,
int64_t  resultNumber,
SmallVector< OpFoldResult > &  offsets,
SmallVector< OpFoldResult > &  sizes 
)
static

◆ tileAndFuseFirstExtractUse()

static std::tuple<SmallVector<Operation *>, Operation *> tileAndFuseFirstExtractUse ( RewriterBase &  rewriter,
Diagnostic &  diag,
Operation *  producerOp,
Operation *  containingOp 
)
static

Find the first "extract" user of producerOp and tile it right before its use.

The tiled op is fused under the containingOp. Return this fused op on success or nullptr if anything fails. If tiled op has uses that are dominated by containingOp, return a new containingOp with results of the fused op appended to results of the containingOp or nullptr if there are no dominated uses.

Definition at line 718 of file LinalgTransformOps.cpp.

References DBGS, diag(), mlir::Operation::getLoc(), replaceForAllWithNewSignature(), mlir::RewriterBase::replaceOp(), and mlir::OpBuilder::setInsertionPoint().

◆ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument()

static SmallVector<Operation *> tileAndFuseFirstExtractUseThroughContainingOpBlockArgument ( RewriterBase &  rewriter,
Diagnostic &  diag,
Operation *  producerOp,
Operation *  containingOp 
)
static

First, find the first "scf::ForallOp" user of producerOp and ensure it is exactly the containingOp, otherwise bail.

Then, find the first "extract" user of the tied block argument and tile it right before its "extract" use. The tiled op is fused under the containingOp. Return this fused op on success or nullptr if anything fails.

Definition at line 800 of file LinalgTransformOps.cpp.

References mlir::OpBuilder::clone(), DBGS, diag(), mlir::RewriterBase::eraseOp(), mlir::IROperand< DerivedT, IRValueT >::get(), mlir::Operation::getLoc(), mlir::OpOperand::getOperandNumber(), mlir::tensor::getOrCreateDestinations(), mlir::Value::getUsers(), mlir::IRMapping::map(), mlir::RewriterBase::modifyOpInPlace(), mlir::RewriterBase::replaceOp(), mlir::OpBuilder::setInsertionPoint(), and mlir::Operation::setOperand().

◆ tryApply()

template<typename PatternTy , typename... Args>
static FailureOr<LinalgOp> tryApply ( Operation *  operation,
Args &&...  args 
)
static

Attempts to apply the pattern specified as template argument to the given operation.

The pattern is expected to have a returningMatchAndRewrite function that returns the "main" result or failure. Returns failure if the pattern failed to apply. Extra arguments are forwarded to the pattern constructor.

Definition at line 65 of file LinalgTransformOps.cpp.

References mlir::Operation::getContext().

◆ unpackSingleIndexResultPayloadOperations() [1/2]

static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations ( transform::TransformState &  state,
TransformOpInterface  transformOp,
SmallVector< OpFoldResult > &  result,
ArrayRef< OpFoldResult >  ofrs 
)
static

Assuming that ofr is an index attr or a param of index type or a transform dialect handle mapped to exactly one op with one index result, return that value.

Definition at line 93 of file LinalgTransformOps.cpp.

◆ unpackSingleIndexResultPayloadOperations() [2/2]

static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations ( transform::TransformState &  state,
TransformOpInterface  transformOp,
SmallVector< OpFoldResult > &  result,
Value  packedHandle 
)
static

Definition at line 145 of file LinalgTransformOps.cpp.