18#include "llvm/ADT/SmallVectorExtras.h"
19#include "llvm/Support/DebugLog.h"
20#include "llvm/Support/FormatVariadic.h"
21#include "llvm/Support/InterleavedRange.h"
25#define DEBUG_TYPE "linalg-transforms"
35 if (!isa<linalg::LinalgOp>(current)) {
36 if (getFailurePropagationMode().value_or(
37 FailurePropagationMode::Propagate) ==
38 FailurePropagationMode::Propagate) {
39 return emitSilenceableError() <<
"expected a Linalg op";
42 LDBG() <<
"optional nested matcher expected a Linalg op";
50 MappedValue(current)))) {
54 for (
Operation &nested : getBody()->without_terminator()) {
57 if (
diag.isDefiniteFailure())
63 assert(
diag.isSilenceableFailure());
64 if (getFailurePropagationMode().value_or(
65 FailurePropagationMode::Propagate) ==
66 FailurePropagationMode::Propagate) {
78 LDBG() <<
"optional nested matcher failed: " <<
diag.getMessage();
82 getBody()->getTerminator()->getOpOperands()) {
83 Operation *definingOp = terminatorOperand.get().getDefiningOp();
86 if (definingOp->
getBlock() != getBody())
91 undefinedOperands.push_back(&terminatorOperand);
95 auto filtered = llvm::make_filter_range(
96 getBody()->getTerminator()->getOpOperands(), [&](
OpOperand &opOperand) {
97 return !llvm::is_contained(undefinedOperands, &opOperand);
100 filtered, [](
OpOperand &opOperand) {
return opOperand.
get(); });
101 detail::prepareValueMappings(mappings, definedOperands, state);
102 for (
auto &&[operand, mapping] : llvm::zip_equal(filtered, mappings)) {
111 detail::forwardTerminatorOperands(getBody(), state, results);
115void transform::MatchStructuredOp::getEffects(
122LogicalResult transform::MatchStructuredOp::verify() {
123 if (getBody()->getNumArguments() != 1)
124 return emitOpError() <<
"expected one body argument";
125 if (!isa<TransformHandleTypeInterface>(getBody()->getArgument(0).
getType())) {
126 return emitOpError() <<
"expected body argument to implement "
127 "TransformHandleTypeInterface";
129 for (
Operation &nested : getBody()->without_terminator()) {
130 if (isa<MatchOpInterface>(nested))
134 <<
"expects nested operations to implement MatchOpInterface";
135 diag.attachNote(nested.getLoc()) <<
"offending operation";
147 if (!isa_and_nonnull<MatchStructuredOp>(op->
getParentOp())) {
148 return op->
emitOpError() <<
"expects parent op to be '"
149 << MatchStructuredOp::getOperationName() <<
"'";
160 <<
"expected predicate to apply to the surrounding structured op";
172 auto linalgOp = cast<linalg::LinalgOp>(current);
173 if (std::optional<uint64_t> position = getReductionPosition()) {
177 return emitSilenceableError() <<
"could not match reduction";
179 if (combinerOps.size() != 1) {
180 return emitSilenceableError() <<
"reduction combiner is not a single op";
184 if (getPassthrough()) {
187 return emitSilenceableError() <<
"not a passthrough";
191 if (getElementwise()) {
193 return emitSilenceableError() <<
"not elementwise";
196 if (std::optional<ArrayAttr> contractionOps = getContraction()) {
199 llvm::raw_string_ostream os(message);
204 cast<StringAttr>((*contractionOps)[0]).getValue() &&
206 cast<StringAttr>((*contractionOps)[1]).getValue();
211 return emitSilenceableError() <<
"contraction: " << message;
216LogicalResult transform::MatchStructuredBodyOp::verify() {
217 int64_t numOptions = getReductionPosition().has_value() + getPassthrough() +
218 getElementwise() + getContraction().has_value();
220 if (numOptions > 1) {
221 StringAttr attributeNames[] = {
222 getReductionPositionAttrName(), getPassthroughAttrName(),
223 getElementwiseAttrName(), getContractionAttrName()};
224 return emitOpError() <<
"only one of {" << llvm::interleaved(attributeNames)
228 if (std::optional<ArrayAttr> contractionAttr = getContraction()) {
229 if (contractionAttr->size() != 2) {
230 return emitOpError() <<
"expects " << getContractionAttrName()
231 <<
" to contain two elements";
242transform::MatchStructuredClassifyContractionDimsOp::matchOperation(
245 FailureOr<linalg::ContractionDimensions> contractionDims =
247 if (
failed(contractionDims))
248 return emitSilenceableError() <<
"could not infer contraction dimensions";
253 return llvm::map_to_vector(values, [&](
unsigned value) ->
Attribute {
254 return builder.getI64IntegerAttr(value);
257 results.
setParams(cast<OpResult>(getBatch()),
258 makeI64Attrs(contractionDims->batch));
259 results.
setParams(cast<OpResult>(
getM()), makeI64Attrs(contractionDims->m));
260 results.
setParams(cast<OpResult>(
getN()), makeI64Attrs(contractionDims->n));
261 results.
setParams(cast<OpResult>(getK()), makeI64Attrs(contractionDims->k));
270transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation(
273 FailureOr<linalg::ConvolutionDimensions> convolutionDims =
275 if (
failed(convolutionDims))
276 return emitSilenceableError() <<
"could not infer convolution dimensions";
281 return llvm::map_to_vector(values, [&](
unsigned value) ->
Attribute {
282 return builder.getI64IntegerAttr(value);
285 results.
setParams(cast<OpResult>(getBatch()),
286 makeI64Attrs(convolutionDims->batch));
287 results.
setParams(cast<OpResult>(getOutputImage()),
288 makeI64Attrs(convolutionDims->outputImage));
289 results.
setParams(cast<OpResult>(getOutputChannel()),
290 makeI64Attrs(convolutionDims->outputChannel));
291 results.
setParams(cast<OpResult>(getFilterLoop()),
292 makeI64Attrs(convolutionDims->filterLoop));
293 results.
setParams(cast<OpResult>(getInputChannel()),
294 makeI64Attrs(convolutionDims->inputChannel));
295 results.
setParams(cast<OpResult>(getDepth()),
296 makeI64Attrs(convolutionDims->depth));
300 return builder.getI64IntegerAttr(value);
303 results.
setParams(cast<OpResult>(getStrides()),
304 makeI64AttrsFromI64(convolutionDims->strides));
305 results.
setParams(cast<OpResult>(getDilations()),
306 makeI64AttrsFromI64(convolutionDims->dilations));
322 const char *message) {
324 if (llvm::any_of(reference, [&](
unsigned ref) {
325 return static_cast<int64_t>(ref) == value;
341 auto linalgOp = cast<linalg::LinalgOp>(current);
344 if (!
diag.succeeded())
348 if (getParallel() || getReduction()) {
351 linalgOp.getParallelDims(reference);
352 else if (getReduction())
353 linalgOp.getReductionDims(reference);
357 getParallel() ?
"expects dimension #{0} to be parallel"
358 :
"expects dimension #{0} to be reduction");
359 if (!
diag.succeeded())
371 return builder.getI64IntegerAttr(ranges[dim]);
373 results.
setParams(cast<OpResult>(getResult()), captured);
381 getRawDimList(), op.getNumLoops(), dims);
382 if (
diag.isSilenceableFailure()) {
383 diag.attachNote(op->getLoc())
384 <<
"while considering dimensions of this payload operation";
389LogicalResult transform::MatchStructuredDimOp::verify() {
390 if (getParallel() && getReduction()) {
391 return emitOpError() <<
"cannot request the same dimension to be both "
392 "parallel and reduction";
395 getIsInverted(), getIsAll());
403transform::MatchStructuredElementalBitwidthOp::matchValue(
406 auto setupResult = [&](
int64_t bitwidth) {
408 results.
setParams(cast<OpResult>(getResult()), {attr});
416 if (
auto shapedType = dyn_cast<ShapedType>(type)) {
417 if (shapedType.getElementType().isIntOrFloat())
418 return setupResult(shapedType.getElementTypeBitWidth());
420 return emitSilenceableError()
421 <<
"unsupported type for bitwidth extraction: " << type;
431 auto linalgOp = cast<linalg::LinalgOp>(current);
434 if (!
diag.succeeded())
438 operandMapping.reserve(positions.size());
439 for (
int64_t position : positions) {
441 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(position));
443 return emitSilenceableError() <<
"the indexing map for input #"
444 << position <<
" is not a permutation";
447 return emitSilenceableError()
448 <<
"the indexing map for input #" << position
449 <<
" is not a projected permutation";
456 if (isa<AffineMapParamType>(getResult().
getType())) {
457 operandMapping.emplace_back(AffineMapAttr::get(indexingMap));
461 Value operand = linalgOp.getDpsInputOperand(position)->get();
462 if (isa<TransformValueHandleTypeInterface>(getResult().
getType())) {
463 operandMapping.emplace_back(operand);
468 if (!operandProducer) {
469 return emitSilenceableError()
470 <<
"input #" << position <<
" is not produced by an operation";
472 operandMapping.emplace_back(operandProducer);
482 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
483 op.getNumDpsInputs(), positions);
484 if (
diag.isSilenceableFailure()) {
485 diag.attachNote(op->getLoc())
486 <<
"while considering DPS inputs of this payload operation";
493template <
typename OpTy>
495 if (op.getPermutation() && op.getProjectedPermutation()) {
496 return op.emitOpError()
497 << op.getPermutationAttrName() <<
" and "
498 << op.getProjectedPermutationAttrName() <<
" are mutually exclusive";
500 if (op.getRawPositionList().size() > 1 && op.getResult()) {
501 return op.emitOpError()
502 <<
"cannot bind multiple inputs/inits to the same value";
508LogicalResult transform::MatchStructuredInputOp::verify() {
512 getIsInverted(), getIsAll());
522 auto linalgOp = cast<linalg::LinalgOp>(current);
525 if (!
diag.succeeded())
529 operandMapping.reserve(positions.size());
530 for (
int64_t position : positions) {
532 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(position));
534 return emitSilenceableError() <<
"the indexing map for output(init) #"
535 << position <<
" is not a permutation";
538 return emitSilenceableError() <<
"the indexing map for output(init) #"
539 << position <<
" is not a permutation";
546 if (isa<AffineMapParamType>(getResult().
getType())) {
547 operandMapping.emplace_back(AffineMapAttr::get(indexingMap));
551 Value operand = linalgOp.getDpsInitOperand(position)->get();
552 if (isa<TransformValueHandleTypeInterface>(getResult().
getType())) {
553 operandMapping.emplace_back(operand);
558 if (!operandProducer) {
559 return emitSilenceableError() <<
"output(init) #" << position
560 <<
" is not produced by an operation";
562 operandMapping.emplace_back(operandProducer);
572 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
573 op.getNumDpsInits(), positions);
574 if (
diag.isSilenceableFailure()) {
575 diag.attachNote(op->getLoc())
576 <<
"while considering DPS inits (outputs) of this payload operation";
581LogicalResult transform::MatchStructuredInitOp::verify() {
585 getIsInverted(), getIsAll());
593transform::MatchStructuredNumInputsOp::matchOperation(
596 auto linalgOp = cast<linalg::LinalgOp>(current);
599 results.
setParams(cast<OpResult>(getResult()), {attr});
608transform::MatchStructuredNumInitsOp::matchOperation(
611 auto linalgOp = cast<linalg::LinalgOp>(current);
614 results.
setParams(cast<OpResult>(getResult()), {attr});
625 auto linalgOp = cast<linalg::LinalgOp>(current);
626 int64_t numLoops = linalgOp.getNumLoops();
628 results.
setParams(cast<OpResult>(getRank()), {attr});
639 auto linalgOp = cast<linalg::LinalgOp>(op);
642 if (!
diag.succeeded())
645 Value result = linalgOp.getTiedOpResult(linalgOp.getDpsInitOperand(position));
646 if (isa<TransformValueHandleTypeInterface>(getResult().
getType())) {
651 if (
result.getUsers().empty()) {
652 return emitSilenceableError()
653 <<
"no users of the result #" << getPosition();
657 results.
set(cast<OpResult>(getResult()), {firstUser});
661 if (!llvm::hasSingleElement(
result.getUsers())) {
662 return emitSilenceableError()
663 <<
"more than one result user with single user requested";
665 results.
set(cast<OpResult>(getResult()), {firstUser});
673transform::MatchStructuredResultOp::getPositionFor(linalg::LinalgOp op,
675 auto rawPosition =
static_cast<int64_t>(getPosition());
676 position = rawPosition < 0 ? op.getNumDpsInits() + rawPosition : rawPosition;
677 if (position >= op.getNumDpsInits() || position < 0) {
678 return emitSilenceableError()
679 <<
"position " << rawPosition
680 <<
" overflows the number of results(ints) of the payload operation";
685LogicalResult transform::MatchStructuredResultOp::verify() {
686 if ((getAny() || getSingle()) ^
687 isa<TransformHandleTypeInterface>(getResult().
getType())) {
688 return emitOpError() <<
"expects either the any/single keyword or the type "
689 "value handle result type";
691 if (getAny() && getSingle()) {
692 return emitOpError() <<
"'any' and 'single' are mutually exclusive";
701void transform::MatchStructuredYieldOp::getEffects(
707void transform::MatchStructuredYieldOp::build(
OpBuilder &builder,
712#define GET_OP_CLASSES
713#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static DiagnosedSilenceableFailure containsAll(ArrayRef< unsigned > reference, ArrayRef< int64_t > list, Location loc, const char *message)
Checks if all values from list are also contained in reference.
LogicalResult verifyStructuredOperandOp(OpTy op)
Verifies a matcher op for structured input or output, specifically the attributes specifying the oper...
static std::string diag(const llvm::Value &value)
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI64IntegerAttr(int64_t value)
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
IRValueT get() const
Return the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
This class represents an operand of an operation.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Block * getBlock()
Returns the operation block that contains this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Values.
MLIRContext * getContext()
Return the context this operation is associated with.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a ...
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
uint64_t getN(LevelType lt)
uint64_t getM(LevelType lt)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
This represents an operation in an abstracted form, suitable for use with the builder APIs.