//===- LinalgMatchOps.cpp - Implementation of Linalg match ops ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/FormatVariadic.h"
22 using namespace mlir;
24 #define DEBUG_TYPE "linalg-transforms"
25 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
27 //===----------------------------------------------------------------------===//
28 // StructuredMatchOp
29 //===----------------------------------------------------------------------===//
31 DiagnosedSilenceableFailure transform::MatchStructuredOp::matchOperation(
32  Operation *current, transform::TransformResults &results,
34  // First, check if the payload operation is a structured Linalg operation.
35  if (!isa<linalg::LinalgOp>(current)) {
36  if (getFailurePropagationMode().value_or(
37  FailurePropagationMode::Propagate) ==
38  FailurePropagationMode::Propagate) {
39  return emitSilenceableError() << "expected a Linalg op";
40  }
41  // If errors are suppressed, succeed and set all results to empty lists.
42  LLVM_DEBUG(DBGS() << "optional nested matcher expected a Linalg op");
43  results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation()));
45  }
47  // Bind `current` to the block argument.
48  auto scope = state.make_region_scope(getBodyRegion());
49  if (failed(state.mapBlockArgument(getBody()->getArgument(0),
50  MappedValue(current)))) {
52  }
54  for (Operation &nested : getBody()->without_terminator()) {
56  state.applyTransform(cast<TransformOpInterface>(nested));
57  if (diag.isDefiniteFailure())
58  return diag;
59  if (diag.succeeded())
60  continue;
62  // If propagating errors, do this immediately.
63  assert(diag.isSilenceableFailure());
64  if (getFailurePropagationMode().value_or(
65  FailurePropagationMode::Propagate) ==
66  FailurePropagationMode::Propagate) {
67  return diag;
68  }
70  // If suppressing errors, print the message into the debug stream before
71  // silencing it. Then set all results value that are already known.
72  // Results come from the terminator operands, which may be defined in the
73  // (single) block of this operation or above it. When they are defined
74  // above, they are known to be mapped at this point per SSA dominance.
75  // When they are defined in this block, we additionally check if we have
76  // already applied the operation that defines them. If not, the
77  // corresponding results will be set to empty lists.
78  LLVM_DEBUG(DBGS() << "optional nested matcher failed: " << diag.getMessage()
79  << "\n");
80  (void)diag.silence();
81  SmallVector<OpOperand *> undefinedOperands;
82  for (OpOperand &terminatorOperand :
83  getBody()->getTerminator()->getOpOperands()) {
84  Operation *definingOp = terminatorOperand.get().getDefiningOp();
85  if (!definingOp)
86  continue;
87  if (definingOp->getBlock() != getBody())
88  continue;
89  if (definingOp->isBeforeInBlock(&nested))
90  continue;
92  undefinedOperands.push_back(&terminatorOperand);
93  }
96  auto filtered = llvm::make_filter_range(
97  getBody()->getTerminator()->getOpOperands(), [&](OpOperand &opOperand) {
98  return !llvm::is_contained(undefinedOperands, &opOperand);
99  });
100  SmallVector<Value> definedOperands = llvm::to_vector(llvm::map_range(
101  filtered, [](OpOperand &opOperand) { return opOperand.get(); }));
102  detail::prepareValueMappings(mappings, definedOperands, state);
103  for (auto &&[operand, mapping] : llvm::zip_equal(filtered, mappings)) {
104  results.setMappedValues(getResults()[operand.getOperandNumber()],
105  mapping);
106  }
107  results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation()));
109  }
111  // Set the results.
112  detail::forwardTerminatorOperands(getBody(), state, results);
114 }
116 void transform::MatchStructuredOp::getEffects(
118  onlyReadsHandle(getCurrentMutable(), effects);
119  onlyReadsPayload(effects);
120  producesHandle(getOperation()->getOpResults(), effects);
121 }
123 LogicalResult transform::MatchStructuredOp::verify() {
124  if (getBody()->getNumArguments() != 1)
125  return emitOpError() << "expected one body argument";
126  if (!isa<TransformHandleTypeInterface>(getBody()->getArgument(0).getType())) {
127  return emitOpError() << "expected body argument to implement "
128  "TransformHandleTypeInterface";
129  }
130  for (Operation &nested : getBody()->without_terminator()) {
131  if (isa<MatchOpInterface>(nested))
132  continue;
134  emitOpError()
135  << "expects nested operations to implement MatchOpInterface";
136  diag.attachNote(nested.getLoc()) << "offending operation";
137  return diag;
138  }
139  return success();
140 }
142 //===----------------------------------------------------------------------===//
143 // StructuredOpPredicateOpTrait
144 //===----------------------------------------------------------------------===//
147  Operation *op, Value structuredOpHandle) {
148  if (!isa_and_nonnull<MatchStructuredOp>(op->getParentOp())) {
149  return op->emitOpError() << "expects parent op to be '"
150  << MatchStructuredOp::getOperationName() << "'";
151  }
153  // Bail out here, let the verifier of the parent complain.
154  Operation *parent = op->getParentOp();
155  if (parent->getNumRegions() < 1 || parent->getRegion(0).empty() ||
156  parent->getRegion(0).front().getNumArguments() < 1)
157  return success();
159  if (structuredOpHandle != parent->getRegion(0).front().getArgument(0)) {
160  return op->emitOpError()
161  << "expected predicate to apply to the surrounding structured op";
162  }
163  return success();
164 }
166 //===----------------------------------------------------------------------===//
167 // MatchStructuredBodyOp
168 //===----------------------------------------------------------------------===//
170 DiagnosedSilenceableFailure transform::MatchStructuredBodyOp::matchOperation(
171  Operation *current, transform::TransformResults &results,
172  transform::TransformState &state) {
173  auto linalgOp = cast<linalg::LinalgOp>(current);
174  if (std::optional<uint64_t> position = getReductionPosition()) {
175  SmallVector<Operation *> combinerOps;
176  if (!matchReduction(linalgOp.getRegionOutputArgs(), *position,
177  combinerOps)) {
178  return emitSilenceableError() << "could not match reduction";
179  }
180  if (combinerOps.size() != 1) {
181  return emitSilenceableError() << "reduction combiner is not a single op";
182  }
184  }
185  if (getPassthrough()) {
186  Block &body = linalgOp->getRegion(0).front();
187  if (body.getTerminator()->getOperands() != linalgOp.getRegionInputArgs()) {
188  return emitSilenceableError() << "not a passthrough";
189  }
191  }
192  if (getElementwise()) {
193  if (!isElementwise(linalgOp))
194  return emitSilenceableError() << "not elementwise";
196  }
197  if (std::optional<ArrayAttr> contractionOps = getContraction()) {
198  Block &body = linalgOp->getRegion(0).front();
199  std::string message;
200  llvm::raw_string_ostream os(message);
201  bool result = linalg::detail::isContractionBody(
202  body,
203  [&](Operation *elem, Operation *red) {
204  return elem->getName().getStringRef() ==
205  cast<StringAttr>((*contractionOps)[0]).getValue() &&
206  red->getName().getStringRef() ==
207  cast<StringAttr>((*contractionOps)[1]).getValue();
208  },
209  os);
210  if (result)
212  return emitSilenceableError() << "contraction: " << message;
213  }
214  return emitDefiniteFailure() << "unknown body condition";
215 }
218  int64_t numOptions = getReductionPosition().has_value() + getPassthrough() +
219  getElementwise() + getContraction().has_value();
221  if (numOptions > 1) {
222  std::string attributeNames;
223  llvm::raw_string_ostream os(attributeNames);
224  llvm::interleaveComma(ArrayRef<StringAttr>{getReductionPositionAttrName(),
225  getPassthroughAttrName(),
226  getElementwiseAttrName(),
227  getContractionAttrName()},
228  os);
229  return emitOpError() << "only one of {" << attributeNames << "} is allowed";
230  }
232  if (std::optional<ArrayAttr> contractionAttr = getContraction()) {
233  if (contractionAttr->size() != 2) {
234  return emitOpError() << "expects " << getContractionAttrName()
235  << " to contain two elements";
236  }
237  }
238  return success();
239 }
241 //===----------------------------------------------------------------------===//
242 // MatchStructuredClassifyContractionDimsOp
243 //===----------------------------------------------------------------------===//
246 transform::MatchStructuredClassifyContractionDimsOp::matchOperation(
247  Operation *current, transform::TransformResults &results,
248  transform::TransformState &state) {
249  FailureOr<linalg::ContractionDimensions> contractionDims =
250  linalg::inferContractionDims(cast<linalg::LinalgOp>(current));
251  if (failed(contractionDims))
252  return emitSilenceableError() << "could not infer contraction dimensions";
254  MLIRContext *context = current->getContext();
255  Builder builder(context);
256  auto makeI64Attrs = [&](ArrayRef<unsigned> values) {
257  return llvm::to_vector(
258  llvm::map_range(values, [&](unsigned value) -> Attribute {
259  return builder.getI64IntegerAttr(value);
260  }));
261  };
262  results.setParams(cast<OpResult>(getBatch()),
263  makeI64Attrs(contractionDims->batch));
264  results.setParams(cast<OpResult>(getM()), makeI64Attrs(contractionDims->m));
265  results.setParams(cast<OpResult>(getN()), makeI64Attrs(contractionDims->n));
266  results.setParams(cast<OpResult>(getK()), makeI64Attrs(contractionDims->k));
268 }
270 //===----------------------------------------------------------------------===//
271 // MatchStructuredClassifyConvolutionDimsOp
272 //===----------------------------------------------------------------------===//
275 transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation(
276  Operation *current, transform::TransformResults &results,
277  transform::TransformState &state) {
278  FailureOr<linalg::ConvolutionDimensions> convolutionDims =
279  linalg::inferConvolutionDims(cast<linalg::LinalgOp>(current));
280  if (failed(convolutionDims))
281  return emitSilenceableError() << "could not infer convolution dimensions";
283  MLIRContext *context = current->getContext();
284  Builder builder(context);
285  auto makeI64Attrs = [&](ArrayRef<unsigned> values) {
286  return llvm::to_vector(
287  llvm::map_range(values, [&](unsigned value) -> Attribute {
288  return builder.getI64IntegerAttr(value);
289  }));
290  };
291  results.setParams(cast<OpResult>(getBatch()),
292  makeI64Attrs(convolutionDims->batch));
293  results.setParams(cast<OpResult>(getOutputImage()),
294  makeI64Attrs(convolutionDims->outputImage));
295  results.setParams(cast<OpResult>(getOutputChannel()),
296  makeI64Attrs(convolutionDims->outputChannel));
297  results.setParams(cast<OpResult>(getFilterLoop()),
298  makeI64Attrs(convolutionDims->filterLoop));
299  results.setParams(cast<OpResult>(getInputChannel()),
300  makeI64Attrs(convolutionDims->inputChannel));
301  results.setParams(cast<OpResult>(getDepth()),
302  makeI64Attrs(convolutionDims->depth));
304  auto makeI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
305  return llvm::to_vector(
306  llvm::map_range(values, [&](int64_t value) -> Attribute {
307  return builder.getI64IntegerAttr(value);
308  }));
309  };
310  results.setParams(cast<OpResult>(getStrides()),
311  makeI64AttrsFromI64(convolutionDims->strides));
312  results.setParams(cast<OpResult>(getDilations()),
313  makeI64AttrsFromI64(convolutionDims->dilations));
315 }
317 //===----------------------------------------------------------------------===//
318 // Utilities for structured match predicates.
319 //===----------------------------------------------------------------------===//
321 /// Checks if all values from `list` are also contained in `reference`. Returns
322 /// a silenceable error with the given message at the given location when it is
323 /// not the case. The error message must contain the "{0}" placeholder that
324 /// will be substituted with the value from `list` that is not contained in
325 /// `reference`.
327  ArrayRef<int64_t> list,
328  Location loc,
329  const char *message) {
330  for (int64_t value : list) {
331  if (llvm::any_of(reference, [&](unsigned ref) {
332  return static_cast<int64_t>(ref) == value;
333  })) {
334  continue;
335  }
336  return emitSilenceableFailure(loc) << llvm::formatv(message, value);
337  }
339 }
341 //===----------------------------------------------------------------------===//
342 // MatchStructuredDimOp
343 //===----------------------------------------------------------------------===//
345 DiagnosedSilenceableFailure transform::MatchStructuredDimOp::matchOperation(
346  Operation *current, transform::TransformResults &results,
347  transform::TransformState &state) {
348  auto linalgOp = cast<linalg::LinalgOp>(current);
349  SmallVector<int64_t> dimensions;
350  DiagnosedSilenceableFailure diag = getDimensionsFor(linalgOp, dimensions);
351  if (!diag.succeeded())
352  return diag;
354  // If asked to check for the kind of dimension, perform the check.
355  if (getParallel() || getReduction()) {
356  SmallVector<unsigned> reference;
357  if (getParallel())
358  linalgOp.getParallelDims(reference);
359  else if (getReduction())
360  linalgOp.getReductionDims(reference);
363  containsAll(reference, dimensions, getLoc(),
364  getParallel() ? "expects dimension #{0} to be parallel"
365  : "expects dimension #{0} to be reduction");
366  if (!diag.succeeded())
367  return diag;
368  }
370  // If not capturing, we are done here.
371  if (!getResult())
372  return diag;
374  SmallVector<int64_t, 4> ranges = linalgOp.getStaticLoopRanges();
375  Builder builder(current);
376  SmallVector<Attribute> captured = llvm::to_vector(
377  llvm::map_range(dimensions, [&](int64_t dim) -> Attribute {
378  return builder.getI64IntegerAttr(ranges[dim]);
379  }));
380  results.setParams(cast<OpResult>(getResult()), captured);
382 }
384 DiagnosedSilenceableFailure transform::MatchStructuredDimOp::getDimensionsFor(
385  linalg::LinalgOp op, SmallVectorImpl<int64_t> &dims) {
387  expandTargetSpecification(getLoc(), getIsAll(), getIsInverted(),
388  getRawDimList(), op.getNumLoops(), dims);
389  if (diag.isSilenceableFailure()) {
390  diag.attachNote(op->getLoc())
391  << "while considering dimensions of this payload operation";
392  }
393  return diag;
394 }
397  if (getParallel() && getReduction()) {
398  return emitOpError() << "cannot request the same dimension to be both "
399  "parallel and reduction";
400  }
401  return verifyTransformMatchDimsOp(getOperation(), getRawDimList(),
402  getIsInverted(), getIsAll());
403 }
405 //===----------------------------------------------------------------------===//
406 // MatchStructuredElementalBitwidthOp
407 //===----------------------------------------------------------------------===//
410 transform::MatchStructuredElementalBitwidthOp::matchValue(
411  Value current, transform::TransformResults &results,
412  transform::TransformState &state) {
413  auto setupResult = [&](int64_t bitwidth) {
414  Attribute attr = Builder(current.getContext()).getI64IntegerAttr(bitwidth);
415  results.setParams(cast<OpResult>(getResult()), {attr});
417  };
419  Type type = current.getType();
420  if (type.isIntOrFloat())
421  return setupResult(type.getIntOrFloatBitWidth());
423  if (auto shapedType = dyn_cast<ShapedType>(type)) {
424  if (shapedType.getElementType().isIntOrFloat())
425  return setupResult(shapedType.getElementTypeBitWidth());
426  }
427  return emitSilenceableError()
428  << "unsupported type for bitwidth extraction: " << type;
429 }
431 //===----------------------------------------------------------------------===//
432 // MatchStructuredInputOp
433 //===----------------------------------------------------------------------===//
435 DiagnosedSilenceableFailure transform::MatchStructuredInputOp::matchOperation(
436  Operation *current, transform::TransformResults &results,
437  transform::TransformState &state) {
438  auto linalgOp = cast<linalg::LinalgOp>(current);
439  SmallVector<int64_t> positions;
440  DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions);
441  if (!diag.succeeded())
442  return diag;
444  SmallVector<MappedValue> operandMapping;
445  operandMapping.reserve(positions.size());
446  for (int64_t position : positions) {
447  AffineMap indexingMap =
448  linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(position));
449  if (getPermutation() && !indexingMap.isPermutation()) {
450  return emitSilenceableError() << "the indexing map for input #"
451  << position << " is not a permutation";
452  }
453  if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) {
454  return emitSilenceableError()
455  << "the indexing map for input #" << position
456  << " is not a projected permutation";
457  }
459  // If capture not requested, skip it.
460  if (!getResult())
461  continue;
463  if (isa<AffineMapParamType>(getResult().getType())) {
464  operandMapping.emplace_back(AffineMapAttr::get(indexingMap));
465  continue;
466  }
468  Value operand = linalgOp.getDpsInputOperand(position)->get();
469  if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
470  operandMapping.emplace_back(operand);
471  continue;
472  }
474  Operation *operandProducer = operand.getDefiningOp();
475  if (!operandProducer) {
476  return emitSilenceableError()
477  << "input #" << position << " is not produced by an operation";
478  }
479  operandMapping.emplace_back(operandProducer);
480  }
481  if (getResult())
482  results.setMappedValues(cast<OpResult>(getResult()), operandMapping);
484 }
486 DiagnosedSilenceableFailure transform::MatchStructuredInputOp::getPositionsFor(
487  linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) {
489  getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
490  op.getNumDpsInputs(), positions);
491  if (diag.isSilenceableFailure()) {
492  diag.attachNote(op->getLoc())
493  << "while considering DPS inputs of this payload operation";
494  }
495  return diag;
496 }
498 /// Verifies a matcher op for structured input or output, specifically the
499 /// attributes specifying the operand positions.
500 template <typename OpTy>
501 LogicalResult verifyStructuredOperandOp(OpTy op) {
502  if (op.getPermutation() && op.getProjectedPermutation()) {
503  return op.emitOpError()
504  << op.getPermutationAttrName() << " and "
505  << op.getProjectedPermutationAttrName() << " are mutually exclusive";
506  }
507  if (op.getRawPositionList().size() > 1 && op.getResult()) {
508  return op.emitOpError()
509  << "cannot bind multiple inputs/inits to the same value";
510  }
512  return success();
513 }
516  if (failed(verifyStructuredOperandOp(*this)))
517  return failure();
518  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
519  getIsInverted(), getIsAll());
520 }
522 //===----------------------------------------------------------------------===//
523 // MatchStructuredInitOp
524 //===----------------------------------------------------------------------===//
526 DiagnosedSilenceableFailure transform::MatchStructuredInitOp::matchOperation(
527  Operation *current, transform::TransformResults &results,
528  transform::TransformState &state) {
529  auto linalgOp = cast<linalg::LinalgOp>(current);
530  SmallVector<int64_t> positions;
531  DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions);
532  if (!diag.succeeded())
533  return diag;
535  SmallVector<MappedValue> operandMapping;
536  operandMapping.reserve(positions.size());
537  for (int64_t position : positions) {
538  AffineMap indexingMap =
539  linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(position));
540  if (getPermutation() && !indexingMap.isPermutation()) {
541  return emitSilenceableError() << "the indexing map for output(init) #"
542  << position << " is not a permutation";
543  }
544  if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) {
545  return emitSilenceableError() << "the indexing map for output(init) #"
546  << position << " is not a permutation";
547  }
549  // If capture not requested, skip it.
550  if (!getResult())
551  continue;
553  if (isa<AffineMapParamType>(getResult().getType())) {
554  operandMapping.emplace_back(AffineMapAttr::get(indexingMap));
555  continue;
556  }
558  Value operand = linalgOp.getDpsInitOperand(position)->get();
559  if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
560  operandMapping.emplace_back(operand);
561  continue;
562  }
564  Operation *operandProducer = operand.getDefiningOp();
565  if (!operandProducer) {
566  return emitSilenceableError() << "output(init) #" << position
567  << " is not produced by an operation";
568  }
569  operandMapping.emplace_back(operandProducer);
570  }
571  if (getResult())
572  results.setMappedValues(cast<OpResult>(getResult()), operandMapping);
574 }
576 DiagnosedSilenceableFailure transform::MatchStructuredInitOp::getPositionsFor(
577  linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) {
579  getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
580  op.getNumDpsInits(), positions);
581  if (diag.isSilenceableFailure()) {
582  diag.attachNote(op->getLoc())
583  << "while considering DPS inits (outputs) of this payload operation";
584  }
585  return diag;
586 }
589  if (failed(verifyStructuredOperandOp(*this)))
590  return failure();
591  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
592  getIsInverted(), getIsAll());
593 }
595 //===----------------------------------------------------------------------===//
596 // MatchStructuredNumInputsOp
597 //===----------------------------------------------------------------------===//
600 transform::MatchStructuredNumInputsOp::matchOperation(
601  Operation *current, transform::TransformResults &results,
602  transform::TransformState &state) {
603  auto linalgOp = cast<linalg::LinalgOp>(current);
604  Attribute attr =
605  Builder(current).getI64IntegerAttr(linalgOp.getNumDpsInputs());
606  results.setParams(cast<OpResult>(getResult()), {attr});
608 }
610 //===----------------------------------------------------------------------===//
611 // MatchStructuredNumInitsOp
612 //===----------------------------------------------------------------------===//
615 transform::MatchStructuredNumInitsOp::matchOperation(
616  Operation *current, transform::TransformResults &results,
617  transform::TransformState &state) {
618  auto linalgOp = cast<linalg::LinalgOp>(current);
619  Attribute attr =
620  Builder(current).getI64IntegerAttr(linalgOp.getNumDpsInits());
621  results.setParams(cast<OpResult>(getResult()), {attr});
623 }
625 //===----------------------------------------------------------------------===//
626 // MatchStructuredRankOp
627 //===----------------------------------------------------------------------===//
629 DiagnosedSilenceableFailure transform::MatchStructuredRankOp::matchOperation(
630  Operation *current, transform::TransformResults &results,
631  transform::TransformState &state) {
632  auto linalgOp = cast<linalg::LinalgOp>(current);
633  int64_t numLoops = linalgOp.getNumLoops();
634  Attribute attr = Builder(linalgOp->getContext()).getI64IntegerAttr(numLoops);
635  results.setParams(cast<OpResult>(getRank()), {attr});
637 }
639 //===----------------------------------------------------------------------===//
640 // MatchStructuredResultOp
641 //===----------------------------------------------------------------------===//
643 DiagnosedSilenceableFailure transform::MatchStructuredResultOp::matchOperation(
645  transform::TransformState &state) {
646  auto linalgOp = cast<linalg::LinalgOp>(op);
647  int64_t position;
648  DiagnosedSilenceableFailure diag = getPositionFor(linalgOp, position);
649  if (!diag.succeeded())
650  return diag;
652  Value result = linalgOp.getTiedOpResult(linalgOp.getDpsInitOperand(position));
653  if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
654  results.setValues(cast<OpResult>(getResult()), {result});
656  }
658  if (result.getUsers().empty()) {
659  return emitSilenceableError()
660  << "no users of the result #" << getPosition();
661  }
662  Operation *firstUser = *result.getUsers().begin();
663  if (getAny()) {
664  results.set(cast<OpResult>(getResult()), {firstUser});
666  }
667  if (getSingle()) {
668  if (!llvm::hasSingleElement(result.getUsers())) {
669  return emitSilenceableError()
670  << "more than one result user with single user requested";
671  }
672  results.set(cast<OpResult>(getResult()), {firstUser});
674  }
676  return emitDefiniteFailure() << "unknown sub-predicate";
677 }
680 transform::MatchStructuredResultOp::getPositionFor(linalg::LinalgOp op,
681  int64_t &position) {
682  auto rawPosition = static_cast<int64_t>(getPosition());
683  position = rawPosition < 0 ? op.getNumDpsInits() + rawPosition : rawPosition;
684  if (position >= op.getNumDpsInits() || position < 0) {
685  return emitSilenceableError()
686  << "position " << rawPosition
687  << " overflows the number of results(ints) of the payload operation";
688  }
690 }
693  if ((getAny() || getSingle()) ^
694  isa<TransformHandleTypeInterface>(getResult().getType())) {
695  return emitOpError() << "expects either the any/single keyword or the type "
696  "value handle result type";
697  }
698  if (getAny() && getSingle()) {
699  return emitOpError() << "'any' and 'single' are mutually exclusive";
700  }
701  return success();
702 }
704 //===----------------------------------------------------------------------===//
705 // MatchStructuredYieldOp
706 //===----------------------------------------------------------------------===//
708 void transform::MatchStructuredYieldOp::getEffects(
710  onlyReadsHandle(getHandlesMutable(), effects);
711  onlyReadsPayload(effects);
712 }
714 void transform::MatchStructuredYieldOp::build(OpBuilder &builder,
715  OperationState &state) {
716  build(builder, state, ValueRange());
717 }
719 #define GET_OP_CLASSES
720 #include "mlir/Dialect/Linalg/TransformOps/"
