18 #include "llvm/Support/DebugLog.h"
19 #include "llvm/Support/FormatVariadic.h"
20 #include "llvm/Support/InterleavedRange.h"
24 #define DEBUG_TYPE "linalg-transforms"
34 if (!isa<linalg::LinalgOp>(current)) {
35 if (getFailurePropagationMode().value_or(
36 FailurePropagationMode::Propagate) ==
37 FailurePropagationMode::Propagate) {
38 return emitSilenceableError() <<
"expected a Linalg op";
41 LDBG() <<
"optional nested matcher expected a Linalg op";
47 auto scope = state.make_region_scope(getBodyRegion());
48 if (failed(state.mapBlockArgument(getBody()->getArgument(0),
53 for (
Operation &nested : getBody()->without_terminator()) {
55 state.applyTransform(cast<TransformOpInterface>(nested));
56 if (
diag.isDefiniteFailure())
62 assert(
diag.isSilenceableFailure());
63 if (getFailurePropagationMode().value_or(
64 FailurePropagationMode::Propagate) ==
65 FailurePropagationMode::Propagate) {
77 LDBG() <<
"optional nested matcher failed: " <<
diag.getMessage();
81 getBody()->getTerminator()->getOpOperands()) {
82 Operation *definingOp = terminatorOperand.get().getDefiningOp();
85 if (definingOp->
getBlock() != getBody())
90 undefinedOperands.push_back(&terminatorOperand);
94 auto filtered = llvm::make_filter_range(
95 getBody()->getTerminator()->getOpOperands(), [&](
OpOperand &opOperand) {
96 return !llvm::is_contained(undefinedOperands, &opOperand);
99 filtered, [](
OpOperand &opOperand) {
return opOperand.
get(); }));
101 for (
auto &&[operand, mapping] : llvm::zip_equal(filtered, mappings)) {
114 void transform::MatchStructuredOp::getEffects(
122 if (getBody()->getNumArguments() != 1)
123 return emitOpError() <<
"expected one body argument";
124 if (!isa<TransformHandleTypeInterface>(getBody()->getArgument(0).
getType())) {
125 return emitOpError() <<
"expected body argument to implement "
126 "TransformHandleTypeInterface";
128 for (
Operation &nested : getBody()->without_terminator()) {
129 if (isa<MatchOpInterface>(nested))
133 <<
"expects nested operations to implement MatchOpInterface";
134 diag.attachNote(nested.getLoc()) <<
"offending operation";
146 if (!isa_and_nonnull<MatchStructuredOp>(op->
getParentOp())) {
147 return op->
emitOpError() <<
"expects parent op to be '"
148 << MatchStructuredOp::getOperationName() <<
"'";
159 <<
"expected predicate to apply to the surrounding structured op";
171 auto linalgOp = cast<linalg::LinalgOp>(current);
172 if (std::optional<uint64_t> position = getReductionPosition()) {
176 return emitSilenceableError() <<
"could not match reduction";
178 if (combinerOps.size() != 1) {
179 return emitSilenceableError() <<
"reduction combiner is not a single op";
183 if (getPassthrough()) {
186 return emitSilenceableError() <<
"not a passthrough";
190 if (getElementwise()) {
192 return emitSilenceableError() <<
"not elementwise";
195 if (std::optional<ArrayAttr> contractionOps = getContraction()) {
198 llvm::raw_string_ostream os(message);
202 return elem->
getName().getStringRef() ==
203 cast<StringAttr>((*contractionOps)[0]).getValue() &&
205 cast<StringAttr>((*contractionOps)[1]).getValue();
210 return emitSilenceableError() <<
"contraction: " << message;
216 int64_t numOptions = getReductionPosition().has_value() + getPassthrough() +
217 getElementwise() + getContraction().has_value();
219 if (numOptions > 1) {
220 StringAttr attributeNames[] = {
221 getReductionPositionAttrName(), getPassthroughAttrName(),
222 getElementwiseAttrName(), getContractionAttrName()};
223 return emitOpError() <<
"only one of {" << llvm::interleaved(attributeNames)
227 if (std::optional<ArrayAttr> contractionAttr = getContraction()) {
228 if (contractionAttr->size() != 2) {
229 return emitOpError() <<
"expects " << getContractionAttrName()
230 <<
" to contain two elements";
241 transform::MatchStructuredClassifyContractionDimsOp::matchOperation(
244 FailureOr<linalg::ContractionDimensions> contractionDims =
246 if (failed(contractionDims))
247 return emitSilenceableError() <<
"could not infer contraction dimensions";
252 return llvm::to_vector(
253 llvm::map_range(values, [&](
unsigned value) ->
Attribute {
254 return builder.getI64IntegerAttr(value);
257 results.
setParams(cast<OpResult>(getBatch()),
258 makeI64Attrs(contractionDims->batch));
259 results.
setParams(cast<OpResult>(
getM()), makeI64Attrs(contractionDims->m));
260 results.
setParams(cast<OpResult>(
getN()), makeI64Attrs(contractionDims->n));
261 results.
setParams(cast<OpResult>(getK()), makeI64Attrs(contractionDims->k));
270 transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation(
273 FailureOr<linalg::ConvolutionDimensions> convolutionDims =
275 if (failed(convolutionDims))
276 return emitSilenceableError() <<
"could not infer convolution dimensions";
281 return llvm::to_vector(
282 llvm::map_range(values, [&](
unsigned value) ->
Attribute {
283 return builder.getI64IntegerAttr(value);
286 results.
setParams(cast<OpResult>(getBatch()),
287 makeI64Attrs(convolutionDims->batch));
288 results.
setParams(cast<OpResult>(getOutputImage()),
289 makeI64Attrs(convolutionDims->outputImage));
290 results.
setParams(cast<OpResult>(getOutputChannel()),
291 makeI64Attrs(convolutionDims->outputChannel));
292 results.
setParams(cast<OpResult>(getFilterLoop()),
293 makeI64Attrs(convolutionDims->filterLoop));
294 results.
setParams(cast<OpResult>(getInputChannel()),
295 makeI64Attrs(convolutionDims->inputChannel));
296 results.
setParams(cast<OpResult>(getDepth()),
297 makeI64Attrs(convolutionDims->depth));
300 return llvm::to_vector(
301 llvm::map_range(values, [&](int64_t value) ->
Attribute {
302 return builder.getI64IntegerAttr(value);
305 results.
setParams(cast<OpResult>(getStrides()),
306 makeI64AttrsFromI64(convolutionDims->strides));
307 results.
setParams(cast<OpResult>(getDilations()),
308 makeI64AttrsFromI64(convolutionDims->dilations));
324 const char *message) {
325 for (int64_t value : list) {
326 if (llvm::any_of(reference, [&](
unsigned ref) {
327 return static_cast<int64_t
>(ref) == value;
343 auto linalgOp = cast<linalg::LinalgOp>(current);
346 if (!
diag.succeeded())
350 if (getParallel() || getReduction()) {
353 linalgOp.getParallelDims(reference);
354 else if (getReduction())
355 linalgOp.getReductionDims(reference);
359 getParallel() ?
"expects dimension #{0} to be parallel"
360 :
"expects dimension #{0} to be reduction");
361 if (!
diag.succeeded())
372 llvm::map_range(dimensions, [&](int64_t dim) ->
Attribute {
373 return builder.getI64IntegerAttr(ranges[dim]);
375 results.
setParams(cast<OpResult>(getResult()), captured);
383 getRawDimList(), op.getNumLoops(), dims);
384 if (
diag.isSilenceableFailure()) {
385 diag.attachNote(op->getLoc())
386 <<
"while considering dimensions of this payload operation";
392 if (getParallel() && getReduction()) {
393 return emitOpError() <<
"cannot request the same dimension to be both "
394 "parallel and reduction";
397 getIsInverted(), getIsAll());
405 transform::MatchStructuredElementalBitwidthOp::matchValue(
408 auto setupResult = [&](int64_t bitwidth) {
410 results.
setParams(cast<OpResult>(getResult()), {attr});
418 if (
auto shapedType = dyn_cast<ShapedType>(type)) {
419 if (shapedType.getElementType().isIntOrFloat())
420 return setupResult(shapedType.getElementTypeBitWidth());
422 return emitSilenceableError()
423 <<
"unsupported type for bitwidth extraction: " << type;
433 auto linalgOp = cast<linalg::LinalgOp>(current);
436 if (!
diag.succeeded())
440 operandMapping.reserve(positions.size());
441 for (int64_t position : positions) {
443 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(position));
445 return emitSilenceableError() <<
"the indexing map for input #"
446 << position <<
" is not a permutation";
449 return emitSilenceableError()
450 <<
"the indexing map for input #" << position
451 <<
" is not a projected permutation";
458 if (isa<AffineMapParamType>(getResult().
getType())) {
463 Value operand = linalgOp.getDpsInputOperand(position)->get();
464 if (isa<TransformValueHandleTypeInterface>(getResult().
getType())) {
465 operandMapping.emplace_back(operand);
470 if (!operandProducer) {
471 return emitSilenceableError()
472 <<
"input #" << position <<
" is not produced by an operation";
474 operandMapping.emplace_back(operandProducer);
484 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
485 op.getNumDpsInputs(), positions);
486 if (
diag.isSilenceableFailure()) {
487 diag.attachNote(op->getLoc())
488 <<
"while considering DPS inputs of this payload operation";
495 template <
typename OpTy>
497 if (op.getPermutation() && op.getProjectedPermutation()) {
498 return op.emitOpError()
499 << op.getPermutationAttrName() <<
" and "
500 << op.getProjectedPermutationAttrName() <<
" are mutually exclusive";
502 if (op.getRawPositionList().size() > 1 && op.getResult()) {
503 return op.emitOpError()
504 <<
"cannot bind multiple inputs/inits to the same value";
514 getIsInverted(), getIsAll());
524 auto linalgOp = cast<linalg::LinalgOp>(current);
527 if (!
diag.succeeded())
531 operandMapping.reserve(positions.size());
532 for (int64_t position : positions) {
534 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(position));
536 return emitSilenceableError() <<
"the indexing map for output(init) #"
537 << position <<
" is not a permutation";
540 return emitSilenceableError() <<
"the indexing map for output(init) #"
541 << position <<
" is not a permutation";
548 if (isa<AffineMapParamType>(getResult().
getType())) {
553 Value operand = linalgOp.getDpsInitOperand(position)->get();
554 if (isa<TransformValueHandleTypeInterface>(getResult().
getType())) {
555 operandMapping.emplace_back(operand);
560 if (!operandProducer) {
561 return emitSilenceableError() <<
"output(init) #" << position
562 <<
" is not produced by an operation";
564 operandMapping.emplace_back(operandProducer);
574 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
575 op.getNumDpsInits(), positions);
576 if (
diag.isSilenceableFailure()) {
577 diag.attachNote(op->getLoc())
578 <<
"while considering DPS inits (outputs) of this payload operation";
587 getIsInverted(), getIsAll());
595 transform::MatchStructuredNumInputsOp::matchOperation(
598 auto linalgOp = cast<linalg::LinalgOp>(current);
601 results.
setParams(cast<OpResult>(getResult()), {attr});
610 transform::MatchStructuredNumInitsOp::matchOperation(
613 auto linalgOp = cast<linalg::LinalgOp>(current);
616 results.
setParams(cast<OpResult>(getResult()), {attr});
627 auto linalgOp = cast<linalg::LinalgOp>(current);
628 int64_t numLoops = linalgOp.getNumLoops();
630 results.
setParams(cast<OpResult>(getRank()), {attr});
641 auto linalgOp = cast<linalg::LinalgOp>(op);
644 if (!
diag.succeeded())
647 Value result = linalgOp.getTiedOpResult(linalgOp.getDpsInitOperand(position));
648 if (isa<TransformValueHandleTypeInterface>(getResult().
getType())) {
649 results.
setValues(cast<OpResult>(getResult()), {result});
654 return emitSilenceableError()
655 <<
"no users of the result #" << getPosition();
659 results.
set(cast<OpResult>(getResult()), {firstUser});
663 if (!llvm::hasSingleElement(result.
getUsers())) {
664 return emitSilenceableError()
665 <<
"more than one result user with single user requested";
667 results.
set(cast<OpResult>(getResult()), {firstUser});
675 transform::MatchStructuredResultOp::getPositionFor(linalg::LinalgOp op,
677 auto rawPosition =
static_cast<int64_t
>(getPosition());
678 position = rawPosition < 0 ? op.getNumDpsInits() + rawPosition : rawPosition;
679 if (position >= op.getNumDpsInits() || position < 0) {
680 return emitSilenceableError()
681 <<
"position " << rawPosition
682 <<
" overflows the number of results(ints) of the payload operation";
688 if ((getAny() || getSingle()) ^
689 isa<TransformHandleTypeInterface>(getResult().
getType())) {
690 return emitOpError() <<
"expects either the any/single keyword or the type "
691 "value handle result type";
693 if (getAny() && getSingle()) {
694 return emitOpError() <<
"'any' and 'single' are mutually exclusive";
703 void transform::MatchStructuredYieldOp::getEffects(
709 void transform::MatchStructuredYieldOp::build(
OpBuilder &builder,
714 #define GET_OP_CLASSES
715 #include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"
static DiagnosedSilenceableFailure containsAll(ArrayRef< unsigned > reference, ArrayRef< int64_t > list, Location loc, const char *message)
Checks if all values from list are also contained in reference.
LogicalResult verifyStructuredOperandOp(OpTy op)
Verifies a matcher op for structured input or output, specifically the attributes specifying the oper...
static std::string diag(const llvm::Value &value)
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI64IntegerAttr(int64_t value)
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
IRValueT get() const
Return the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
This class represents an operand of an operation.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Block * getBlock()
Returns the operation block that contains this operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a ...
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
uint64_t getN(LevelType lt)
uint64_t getM(LevelType lt)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This represents an operation in an abstracted form, suitable for use with the builder APIs.