// External model of transform::FindPayloadReplacementOpInterface attached to
// tensor::ExtractSliceOp. Its getNextOperands returns the slice's source
// operand.
// NOTE(review): this is a damaged listing, not compilable C++. The original
// source line numbers are fused into the code (e.g. "29struct"), and the
// lines after embedded line 36 (presumably the closing `}` / `};`) are absent.
29struct ExtractSliceOpReplacementInterface
30 :
public transform::FindPayloadReplacementOpInterface::ExternalModel<
31 ExtractSliceOpReplacementInterface, tensor::ExtractSliceOp> {
32 SmallVector<Value> getNextOperands(Operation *op)
const {
33 auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
// NOTE(review): embedded lines 34-35 are missing from this listing. The
// symbol list at the end of this file mentions isCastLikeExtractSliceOp, so
// they presumably bail out (return {}) for slices that are not cast-like
// before reaching the return below. Confirm against upstream before treating
// the unconditional return shown here as the real behavior.
36 return {extractSliceOp.getSource()};
// External model of transform::FindPayloadReplacementOpInterface attached to
// tensor::InsertSliceOp. Its getNextOperands returns the insert_slice's
// source operand.
// NOTE(review): this is a damaged listing, not compilable C++. The original
// source line numbers are fused into the code (e.g. "40struct"), and the
// lines after embedded line 47 (presumably the closing `}` / `};`) are absent.
40struct InsertSliceOpReplacementInterface
41 :
public transform::FindPayloadReplacementOpInterface::ExternalModel<
42 InsertSliceOpReplacementInterface, tensor::InsertSliceOp> {
43 SmallVector<Value> getNextOperands(Operation *op)
const {
44 auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
// NOTE(review): embedded lines 45-46 are missing from this listing. The
// symbol list at the end of this file mentions isCastLikeInsertSliceOp, so
// they presumably bail out (return {}) for insert_slice ops that are not
// cast-like before reaching the return below. Confirm against upstream.
47 return {insertSliceOp.getSource()};
51struct ReshapeOpReplacementInterface
52 :
public transform::FindPayloadReplacementOpInterface::ExternalModel<
53 ReshapeOpReplacementInterface, tensor::ReshapeOp> {
54 SmallVector<Value> getNextOperands(Operation *op)
const {
55 auto reshapeOp = cast<tensor::ReshapeOp>(op);
56 return {reshapeOp.getSource()};
60template <
typename ConcreteOp>
61struct ReassociativeReshapeOpReplacementInterface
62 :
public transform::FindPayloadReplacementOpInterface::ExternalModel<
63 ReassociativeReshapeOpReplacementInterface<ConcreteOp>, ConcreteOp> {
64 SmallVector<Value> getNextOperands(Operation *op)
const {
65 auto reshapeOp = cast<ConcreteOp>(op);
66 return {reshapeOp.getSrc()};
// Body fragment of the external-model registration. The symbol list at the
// end of this file names a registerFindPayloadReplacementOpInterfaceExternalModels
// function, and this is presumably its body; confirm. The enclosing function
// header and closing lines are not in this listing. Each statement attaches
// one FindPayloadReplacementOpInterface external model on *ctx:
//   tensor.collapse_shape / tensor.expand_shape
//                        -> ReassociativeReshapeOpReplacementInterface<Op>
//   tensor.extract_slice -> ExtractSliceOpReplacementInterface
//   tensor.insert_slice  -> InsertSliceOpReplacementInterface
//   tensor.reshape       -> ReshapeOpReplacementInterface
74 CollapseShapeOp::attachInterface<
75 ReassociativeReshapeOpReplacementInterface<CollapseShapeOp>>(*ctx);
76 ExpandShapeOp::attachInterface<
77 ReassociativeReshapeOpReplacementInterface<ExpandShapeOp>>(*ctx);
78 ExtractSliceOp::attachInterface<ExtractSliceOpReplacementInterface>(*ctx);
79 InsertSliceOp::attachInterface<InsertSliceOpReplacementInterface>(*ctx);
80 ReshapeOp::attachInterface<ReshapeOpReplacementInterface>(*ctx);
// Only the opening line of each of these transform-op pattern-population
// hooks survives in this listing. Parameter lists and bodies are elided, and
// for the split names some class names are cut off after "::".
// NOTE(review): judging by the tensor::populate*Patterns declarations in the
// symbol list at the end of this file, each body presumably forwards to the
// matching populate* helper. Confirm against upstream.
88void transform::ApplyDecomposeTensorConcatPatternsOp::populatePatterns(
93void transform::ApplyDropRedundantInsertSliceRankExpansionPatternsOp::
98void transform::ApplyFoldTensorEmptyPatternsOp::populatePatterns(
103void transform::ApplyFoldTensorSubsetOpsPatternsOp::populatePatterns(
108void transform::ApplyFoldTensorSubsetOpsIntoVectorTransfersPatternsOp::
113void transform::ApplyMergeConsecutiveInsertExtractSlicePatternsOp::
118void transform::ApplyReassociativeReshapeFoldingPatternsOp::populatePatterns(
123void transform::ApplyBubbleUpExtractSlicePatternsOp::populatePatterns(
// Start of ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns. Most
// of its body (embedded lines 129-130 and everything after 132) is elided
// from this listing.
128void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns(
// Surviving fragment of a fold-control predicate over `fusedOperand`
// (presumably an OpOperand *, matching the ControlFoldFn signature in the
// symbol list; confirm). It permits folding only when the operand's value
// has a defining op (getDefiningOp() is non-null) and that op has exactly
// one use.
131 Operation *producer = fusedOperand->get().getDefiningOp();
132 return producer && producer->
hasOneUse();
// Fragment of a TypeConversionCastShapeDynamicDimsOp method that registers
// two tensor.cast-based materializations on `converter`. The rest of the
// method name and several lines inside each lambda are missing from this
// listing.
//  - Source materialization: captures getIgnoreDynamicInfo() by value. It
//    requires exactly one input. When dynamic info is not ignored it applies
//    an extra condition (elided, embedded lines 162-164). It then requires
//    the input type to be cast-compatible with resultType, and finally
//    creates tensor.cast(input) to resultType.
//  - Target materialization: performs the same checks without the
//    dynamic-info condition, then creates tensor.cast(input) to resultType.
// NOTE(review): the values returned when a check fails are elided. They are
// presumably a null Value (i.e. "no materialization"); confirm.
150void transform::TypeConversionCastShapeDynamicDimsOp::
152 bool ignoreDynamicInfo = getIgnoreDynamicInfo();
153 converter.addSourceMaterialization([ignoreDynamicInfo](
157 if (inputs.size() != 1) {
160 Value input = inputs[0];
161 if (!ignoreDynamicInfo &&
165 if (!tensor::CastOp::areCastCompatible(input.
getType(), resultType)) {
168 return tensor::CastOp::create(builder, loc, resultType, input).getResult();
170 converter.addTargetMaterialization([](
OpBuilder &builder,
Type resultType,
173 if (inputs.size() != 1) {
176 Value input = inputs[0];
177 if (!tensor::CastOp::areCastCompatible(input.
getType(), resultType)) {
180 return tensor::CastOp::create(builder, loc, resultType, input).getResult();
// Fragment of a transform op's per-target application. Its header and
// several statements are not in this listing. The visible steps are:
//  1. For i in [0, getNumLoops()), push the induction variable of an
//     scf.for (`nextOp`) into `ivs`. When the loop cannot be found, the
//     diagnostic reads "could not find <i>-th enclosing loop" with a
//     "target op" note (the line creating `diag` is elided).
//  2. Dispatch on the target: tensor.pad or tensor.empty. The other case,
//     presumably a final else branch, reports "unsupported target op".
//  3. On failure (condition elided), emit a silenceable error "could not
//     make target op loop-independent" with a "target op" note.
195 for (uint64_t i = 0, e = getNumLoops(); i < e; ++i) {
199 <<
"could not find " << i
200 <<
"-th enclosing loop";
201 diag.attachNote(
target->getLoc()) <<
"target op";
204 ivs.push_back(cast<scf::ForOp>(nextOp).getInductionVar());
209 if (
auto padOp = dyn_cast<tensor::PadOp>(
target)) {
211 }
else if (
auto emptyOp = dyn_cast<tensor::EmptyOp>(
target)) {
215 <<
"unsupported target op";
216 diag.attachNote(
target->getLoc()) <<
"target op";
221 emitSilenceableError() <<
"could not make target op loop-independent";
222 diag.attachNote(
target->getLoc()) <<
"target op";
// Transform-dialect extension for the tensor transform ops. Its base-class
// line (embedded 236) and several member lines are missing from this
// listing. It calls declareGeneratedDialect for affine::AffineDialect and
// tensor::TensorDialect, then registers the transform ops whose list comes
// from the generated TensorTransformOps.cpp.inc. The macro selecting that
// section (embedded line 248) is presumably GET_OP_LIST; confirm.
235class TensorTransformDialectExtension
237 TensorTransformDialectExtension> {
244 declareGeneratedDialect<affine::AffineDialect>();
245 declareGeneratedDialect<tensor::TensorDialect>();
247 registerTransformOps<
249#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.cpp.inc"
// Pulls in the GET_OP_CLASSES section of the same generated file.
255#define GET_OP_CLASSES
256#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.cpp.inc"
…if copies could not be generated due to yet-unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming and outgoing copies should be inserted. The output argument `nBegin` is set to its replacement (set to `begin` if no invalidation happens). Since outgoing copies could have been inserted at `end`, …
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
void addExtensions()
Add the given extensions to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
bool hasOneUse()
Returns true if this operation has exactly one use.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns, bool foldSingleUseOnly=false)
Populates patterns with patterns that fold tensor.empty with its consumers.
void registerTransformDialectExtension(DialectRegistry &registry)
void populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet &patterns)
Collects patterns to merge consecutive tensor.insert_slice/extract_slice into one.
bool isCastLikeInsertSliceOp(InsertSliceOp op)
A tensor.insert_slice is a cast-like operation if it merely rank-extends the source tensor or inserts...
void populateDecomposeTensorConcatPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that decompose tensor.concat into tensor.empty of a tensor of the co...
void populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into consumer load/store ops into patterns.
void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold tensor.expand_shape and tensor.collapse_shape into other o...
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Appends patterns that are used to bubble up tensor.extract slice op above its producer.
std::function< bool(OpOperand *)> ControlFoldFn
void registerFindPayloadReplacementOpInterfaceExternalModels(DialectRegistry &registry)
bool isCastLikeExtractSliceOp(ExtractSliceOp op)
A tensor.extract_slice is a cast-like operation if it merely rank-reduces unit dimensions of the sour...
void populateDropRedundantInsertSliceRankExpansionPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that drop redundant tensor.insert_slice rank expansions.
FailureOr< Value > buildIndependentOp(OpBuilder &b, tensor::PadOp padOp, ValueRange independencies)
Build a new tensor::PadOp with low/high padding that is independent of all given independencies.
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
void populateRewriteAsConstantPatterns(RewritePatternSet &patterns, const ControlFoldFn &controlFn)
Populates patterns with patterns that replace tensor ops (such as tensor.generate) with constants whe...
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns