MLIR  19.0.0git
LinalgMatchOps.cpp
Go to the documentation of this file.
1 //===- LinalgMatchOps.cpp - Implementation of Linalg match ops ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
18 #include "llvm/Support/Debug.h"
19 #include "llvm/Support/FormatVariadic.h"
20 
21 using namespace mlir;
22 
23 #define DEBUG_TYPE "linalg-transforms"
24 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
25 
26 //===----------------------------------------------------------------------===//
27 // StructuredMatchOp
28 //===----------------------------------------------------------------------===//
29 
30 DiagnosedSilenceableFailure transform::MatchStructuredOp::matchOperation(
31  Operation *current, transform::TransformResults &results,
33  // First, check if the payload operation is a structured Linalg operation.
34  if (!isa<linalg::LinalgOp>(current)) {
35  if (getFailurePropagationMode().value_or(
36  FailurePropagationMode::Propagate) ==
37  FailurePropagationMode::Propagate) {
38  return emitSilenceableError() << "expected a Linalg op";
39  }
40  // If errors are suppressed, succeed and set all results to empty lists.
41  LLVM_DEBUG(DBGS() << "optional nested matcher expected a Linalg op");
42  results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation()));
44  }
45 
46  // Bind `current` to the block argument.
47  auto scope = state.make_region_scope(getBodyRegion());
48  if (failed(state.mapBlockArgument(getBody()->getArgument(0),
49  MappedValue(current)))) {
51  }
52 
53  for (Operation &nested : getBody()->without_terminator()) {
55  state.applyTransform(cast<TransformOpInterface>(nested));
56  if (diag.isDefiniteFailure())
57  return diag;
58  if (diag.succeeded())
59  continue;
60 
61  // If propagating errors, do this immediately.
62  assert(diag.isSilenceableFailure());
63  if (getFailurePropagationMode().value_or(
64  FailurePropagationMode::Propagate) ==
65  FailurePropagationMode::Propagate) {
66  return diag;
67  }
68 
69  // If suppressing errors, print the message into the debug stream before
70  // silencing it. Then set all results value that are already known.
71  // Results come from the terminator operands, which may be defined in the
72  // (single) block of this operation or above it. When they are defined
73  // above, they are known to be mapped at this point per SSA dominance.
74  // When they are defined in this block, we additionally check if we have
75  // already applied the operation that defines them. If not, the
76  // corresponding results will be set to empty lists.
77  LLVM_DEBUG(DBGS() << "optional nested matcher failed: " << diag.getMessage()
78  << "\n");
79  (void)diag.silence();
80  SmallVector<OpOperand *> undefinedOperands;
81  for (OpOperand &terminatorOperand :
82  getBody()->getTerminator()->getOpOperands()) {
83  Operation *definingOp = terminatorOperand.get().getDefiningOp();
84  if (!definingOp)
85  continue;
86  if (definingOp->getBlock() != getBody())
87  continue;
88  if (definingOp->isBeforeInBlock(&nested))
89  continue;
90 
91  undefinedOperands.push_back(&terminatorOperand);
92  }
93 
95  auto filtered = llvm::make_filter_range(
96  getBody()->getTerminator()->getOpOperands(), [&](OpOperand &opOperand) {
97  return !llvm::is_contained(undefinedOperands, &opOperand);
98  });
99  SmallVector<Value> definedOperands = llvm::to_vector(llvm::map_range(
100  filtered, [](OpOperand &opOperand) { return opOperand.get(); }));
101  detail::prepareValueMappings(mappings, definedOperands, state);
102  for (auto &&[operand, mapping] : llvm::zip_equal(filtered, mappings)) {
103  results.setMappedValues(getResults()[operand.getOperandNumber()],
104  mapping);
105  }
106  results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation()));
108  }
109 
110  // Set the results.
111  detail::forwardTerminatorOperands(getBody(), state, results);
113 }
114 
// Memory effects: the matcher reads the `current` handle and the payload
// without modifying them, and produces the `outputs` handles.
115 void transform::MatchStructuredOp::getEffects(
117  onlyReadsHandle(getCurrent(), effects);
118  onlyReadsPayload(effects);
119  producesHandle(getOutputs(), effects);
120 }
121 
// Verifier body: requires exactly one block argument whose type implements
// TransformHandleTypeInterface, and requires every non-terminator nested op to
// implement MatchOpInterface (the offending op is attached as a note).
123  if (getBody()->getNumArguments() != 1)
124  return emitOpError() << "expected one body argument";
125  if (!isa<TransformHandleTypeInterface>(getBody()->getArgument(0).getType())) {
126  return emitOpError() << "expected body argument to implement "
127  "TransformHandleTypeInterface";
128  }
129  for (Operation &nested : getBody()->without_terminator()) {
130  if (isa<MatchOpInterface>(nested))
131  continue;
133  emitOpError()
134  << "expects nested operations to implement MatchOpInterface";
135  diag.attachNote(nested.getLoc()) << "offending operation";
136  return diag;
137  }
138  return success();
139 }
140 
141 //===----------------------------------------------------------------------===//
142 // StructuredOpPredicateOpTrait
143 //===----------------------------------------------------------------------===//
144 
146  Operation *op, Value structuredOpHandle) {
// Verifies that `op` is nested directly in a MatchStructuredOp and that
// `structuredOpHandle` is the first argument of that op's body block, i.e.
// the predicate applies to the structured op being matched.
147  if (!isa_and_nonnull<MatchStructuredOp>(op->getParentOp())) {
148  return op->emitOpError() << "expects parent op to be '"
149  << MatchStructuredOp::getOperationName() << "'";
150  }
151 
152  // Bail out here, let the verifier of the parent complain.
153  Operation *parent = op->getParentOp();
154  if (parent->getNumRegions() < 1 || parent->getRegion(0).empty() ||
155  parent->getRegion(0).front().getNumArguments() < 1)
156  return success();
157 
158  if (structuredOpHandle != parent->getRegion(0).front().getArgument(0)) {
159  return op->emitOpError()
160  << "expected predicate to apply to the surrounding structured op";
161  }
162  return success();
163 }
164 
165 //===----------------------------------------------------------------------===//
166 // MatchStructuredBodyOp
167 //===----------------------------------------------------------------------===//
168 
169 DiagnosedSilenceableFailure transform::MatchStructuredBodyOp::matchOperation(
170  Operation *current, transform::TransformResults &results,
171  transform::TransformState &state) {
172  auto linalgOp = cast<linalg::LinalgOp>(current);
173  if (std::optional<uint64_t> position = getReductionPosition()) {
174  SmallVector<Operation *> combinerOps;
175  if (!matchReduction(linalgOp.getRegionOutputArgs(), *position,
176  combinerOps)) {
177  return emitSilenceableError() << "could not match reduction";
178  }
179  if (combinerOps.size() != 1) {
180  return emitSilenceableError() << "reduction combiner is not a single op";
181  }
183  }
184  if (getPassthrough()) {
185  Block &body = linalgOp->getRegion(0).front();
186  if (body.getTerminator()->getOperands() != linalgOp.getRegionInputArgs()) {
187  return emitSilenceableError() << "not a passthrough";
188  }
190  }
191  if (getElementwise()) {
192  if (!isElementwise(linalgOp))
193  return emitSilenceableError() << "not elementwise";
195  }
196  if (std::optional<ArrayAttr> contractionOps = getContraction()) {
197  Block &body = linalgOp->getRegion(0).front();
198  std::string message;
199  llvm::raw_string_ostream os(message);
200  bool result = linalg::detail::isContractionBody(
201  body,
202  [&](Operation *elem, Operation *red) {
203  return elem->getName().getStringRef() ==
204  (*contractionOps)[0].cast<StringAttr>().getValue() &&
205  red->getName().getStringRef() ==
206  (*contractionOps)[1].cast<StringAttr>().getValue();
207  },
208  os);
209  if (result)
211  return emitSilenceableError() << "contraction: " << os.str();
212  }
213  return emitDefiniteFailure() << "unknown body condition";
214 }
215 
// Verifier body: at most one body condition (reduction position, passthrough,
// elementwise, contraction) may be set; the contraction attribute, if present,
// must hold exactly two elements.
// Optional/bool getters are summed as integers to count the set conditions.
217  int64_t numOptions = getReductionPosition().has_value() + getPassthrough() +
218  getElementwise() + getContraction().has_value();
219 
220  if (numOptions > 1) {
221  std::string attributeNames;
222  llvm::raw_string_ostream os(attributeNames);
223  llvm::interleaveComma(ArrayRef<StringAttr>{getReductionPositionAttrName(),
224  getPassthroughAttrName(),
225  getElementwiseAttrName(),
226  getContractionAttrName()},
227  os);
228  return emitOpError() << "only one of {" << os.str() << "} is allowed";
229  }
230 
231  if (std::optional<ArrayAttr> contractionAttr = getContraction()) {
232  if (contractionAttr->size() != 2) {
233  return emitOpError() << "expects " << getContractionAttrName()
234  << " to contain two elements";
235  }
236  }
237  return success();
238 }
239 
240 //===----------------------------------------------------------------------===//
241 // MatchStructuredClassifyContractionDimsOp
242 //===----------------------------------------------------------------------===//
243 
245 transform::MatchStructuredClassifyContractionDimsOp::matchOperation(
246  Operation *current, transform::TransformResults &results,
247  transform::TransformState &state) {
249  linalg::inferContractionDims(cast<linalg::LinalgOp>(current));
250  if (failed(contractionDims))
251  return emitSilenceableError() << "could not infer contraction dimensions";
252 
253  MLIRContext *context = current->getContext();
254  Builder builder(context);
255  auto makeI64Attrs = [&](ArrayRef<unsigned> values) {
256  return llvm::to_vector(
257  llvm::map_range(values, [&](unsigned value) -> Attribute {
258  return builder.getI64IntegerAttr(value);
259  }));
260  };
261  results.setParams(getBatch().cast<OpResult>(),
262  makeI64Attrs(contractionDims->batch));
263  results.setParams(getM().cast<OpResult>(), makeI64Attrs(contractionDims->m));
264  results.setParams(getN().cast<OpResult>(), makeI64Attrs(contractionDims->n));
265  results.setParams(getK().cast<OpResult>(), makeI64Attrs(contractionDims->k));
267 }
268 
269 //===----------------------------------------------------------------------===//
270 // MatchStructuredClassifyConvolutionDimsOp
271 //===----------------------------------------------------------------------===//
272 
274 transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation(
275  Operation *current, transform::TransformResults &results,
276  transform::TransformState &state) {
278  linalg::inferConvolutionDims(cast<linalg::LinalgOp>(current));
279  if (failed(convolutionDims))
280  return emitSilenceableError() << "could not infer convolution dimensions";
281 
282  MLIRContext *context = current->getContext();
283  Builder builder(context);
284  auto makeI64Attrs = [&](ArrayRef<unsigned> values) {
285  return llvm::to_vector(
286  llvm::map_range(values, [&](unsigned value) -> Attribute {
287  return builder.getI64IntegerAttr(value);
288  }));
289  };
290  results.setParams(getBatch().cast<OpResult>(),
291  makeI64Attrs(convolutionDims->batch));
292  results.setParams(getOutputImage().cast<OpResult>(),
293  makeI64Attrs(convolutionDims->outputImage));
294  results.setParams(getOutputChannel().cast<OpResult>(),
295  makeI64Attrs(convolutionDims->outputChannel));
296  results.setParams(getFilterLoop().cast<OpResult>(),
297  makeI64Attrs(convolutionDims->filterLoop));
298  results.setParams(getInputChannel().cast<OpResult>(),
299  makeI64Attrs(convolutionDims->inputChannel));
300  results.setParams(getDepth().cast<OpResult>(),
301  makeI64Attrs(convolutionDims->depth));
302 
303  auto makeI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
304  return llvm::to_vector(
305  llvm::map_range(values, [&](int64_t value) -> Attribute {
306  return builder.getI64IntegerAttr(value);
307  }));
308  };
309  results.setParams(getStrides().cast<OpResult>(),
310  makeI64AttrsFromI64(convolutionDims->strides));
311  results.setParams(getDilations().cast<OpResult>(),
312  makeI64AttrsFromI64(convolutionDims->dilations));
314 }
315 
316 //===----------------------------------------------------------------------===//
317 // Utilities for structured match predicates.
318 //===----------------------------------------------------------------------===//
319 
320 /// Checks if all values from `list` are also contained in `reference`. Returns
321 /// a silenceable error with the given message at the given location when it is
322 /// not the case. The error message must contain the "{0}" placeholder that
323 /// will be substituted with the value from `list` that is not contained in
324 /// `reference`.
326  ArrayRef<int64_t> list,
327  Location loc,
328  const char *message) {
// Linear scan of `reference` for each element of `list`.
329  for (int64_t value : list) {
// Compare as int64_t: a negative `value` can never equal a widened unsigned.
330  if (llvm::any_of(reference, [&](unsigned ref) {
331  return static_cast<int64_t>(ref) == value;
332  })) {
333  continue;
334  }
335  return emitSilenceableFailure(loc) << llvm::formatv(message, value);
336  }
338 }
339 
340 //===----------------------------------------------------------------------===//
341 // MatchStructuredDimOp
342 //===----------------------------------------------------------------------===//
343 
// Resolves the requested loop dimensions of the payload LinalgOp, optionally
// checks that all of them are parallel (or all reduction), and, if a result is
// present, captures their static loop ranges as i64 params.
344 DiagnosedSilenceableFailure transform::MatchStructuredDimOp::matchOperation(
345  Operation *current, transform::TransformResults &results,
346  transform::TransformState &state) {
347  auto linalgOp = cast<linalg::LinalgOp>(current);
348  SmallVector<int64_t> dimensions;
349  DiagnosedSilenceableFailure diag = getDimensionsFor(linalgOp, dimensions);
350  if (!diag.succeeded())
351  return diag;
352 
353  // If asked to check for the kind of dimension, perform the check.
354  if (getParallel() || getReduction()) {
355  SmallVector<unsigned> reference;
356  if (getParallel())
357  linalgOp.getParallelDims(reference);
358  else if (getReduction())
359  linalgOp.getReductionDims(reference);
360 
362  containsAll(reference, dimensions, getLoc(),
363  getParallel() ? "expects dimension #{0} to be parallel"
364  : "expects dimension #{0} to be reduction");
365  if (!diag.succeeded())
366  return diag;
367  }
368 
369  // If not capturing, we are done here.
// The `diag` returned below is the outer one from getDimensionsFor, which
// succeeded at this point.
370  if (!getResult())
371  return diag;
372 
373  SmallVector<int64_t, 4> ranges = linalgOp.getStaticLoopRanges();
374  Builder builder(current);
375  SmallVector<Attribute> captured = llvm::to_vector(
376  llvm::map_range(dimensions, [&](int64_t dim) -> Attribute {
377  return builder.getI64IntegerAttr(ranges[dim]);
378  }));
379  results.setParams(cast<OpResult>(getResult()), captured);
381 }
382 
// Expands the raw dimension list (honoring the "all" and "inverted" flags)
// against the payload op's loop count into `dims`; on silenceable failure, a
// note pointing at the payload op is attached.
383 DiagnosedSilenceableFailure transform::MatchStructuredDimOp::getDimensionsFor(
384  linalg::LinalgOp op, SmallVectorImpl<int64_t> &dims) {
386  expandTargetSpecification(getLoc(), getIsAll(), getIsInverted(),
387  getRawDimList(), op.getNumLoops(), dims);
388  if (diag.isSilenceableFailure()) {
389  diag.attachNote(op->getLoc())
390  << "while considering dimensions of this payload operation";
391  }
392  return diag;
393 }
394 
// Verifier body: parallel and reduction checks are mutually exclusive; the
// dimension list itself is validated by verifyTransformMatchDimsOp.
396  if (getParallel() && getReduction()) {
397  return emitOpError() << "cannot request the same dimension to be both "
398  "parallel and reduction";
399  }
400  return verifyTransformMatchDimsOp(getOperation(), getRawDimList(),
401  getIsInverted(), getIsAll());
402 }
403 
404 //===----------------------------------------------------------------------===//
405 // MatchStructuredElementalBitwidthOp
406 //===----------------------------------------------------------------------===//
407 
// Captures the bitwidth of `current`'s type as an i64 param: directly for
// integer/float types, or of the element type for shaped types with an
// integer/float element type. Any other type is a silenceable failure.
409 transform::MatchStructuredElementalBitwidthOp::matchValue(
410  Value current, transform::TransformResults &results,
411  transform::TransformState &state) {
412  auto setupResult = [&](int64_t bitwidth) {
413  Attribute attr = Builder(current.getContext()).getI64IntegerAttr(bitwidth);
414  results.setParams(cast<OpResult>(getResult()), {attr});
416  };
417 
418  Type type = current.getType();
419  if (type.isIntOrFloat())
420  return setupResult(type.getIntOrFloatBitWidth());
421 
422  if (auto shapedType = dyn_cast<ShapedType>(type)) {
423  if (shapedType.getElementType().isIntOrFloat())
424  return setupResult(shapedType.getElementTypeBitWidth());
425  }
426  return emitSilenceableError()
427  << "unsupported type for bitwidth extraction: " << type;
428 }
429 
430 //===----------------------------------------------------------------------===//
431 // MatchStructuredInputOp
432 //===----------------------------------------------------------------------===//
433 
// Matches the DPS inputs of the payload LinalgOp selected by the position
// specification. It optionally checks that their indexing maps are
// (projected) permutations. If a result is present, it captures, per input,
// the indexing map, the input value, or the op producing it, depending on the
// result type.
434 DiagnosedSilenceableFailure transform::MatchStructuredInputOp::matchOperation(
435  Operation *current, transform::TransformResults &results,
436  transform::TransformState &state) {
437  auto linalgOp = cast<linalg::LinalgOp>(current);
438  SmallVector<int64_t> positions;
439  DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions);
440  if (!diag.succeeded())
441  return diag;
442 
443  SmallVector<MappedValue> operandMapping;
444  operandMapping.reserve(positions.size());
445  for (int64_t position : positions) {
446  AffineMap indexingMap =
447  linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(position));
448  if (getPermutation() && !indexingMap.isPermutation()) {
449  return emitSilenceableError() << "the indexing map for input #"
450  << position << " is not a permutation";
451  }
452  if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) {
453  return emitSilenceableError()
454  << "the indexing map for input #" << position
455  << " is not a projected permutation";
456  }
457 
458  // If capture not requested, skip it.
459  if (!getResult())
460  continue;
461 
// The result type selects what is captured: an affine map param, a value
// handle, or (otherwise) the producing operation.
462  if (isa<AffineMapParamType>(getResult().getType())) {
463  operandMapping.emplace_back(AffineMapAttr::get(indexingMap));
464  continue;
465  }
466 
467  Value operand = linalgOp.getDpsInputOperand(position)->get();
468  if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
469  operandMapping.emplace_back(operand);
470  continue;
471  }
472 
473  Operation *operandProducer = operand.getDefiningOp();
474  if (!operandProducer) {
475  return emitSilenceableError()
476  << "input #" << position << " is not produced by an operation";
477  }
478  operandMapping.emplace_back(operandProducer);
479  }
480  if (getResult())
481  results.setMappedValues(cast<OpResult>(getResult()), operandMapping);
483 }
484 
// Expands the raw position list (honoring the "all" and "inverted" flags)
// against the payload op's number of DPS inputs into `positions`; on
// silenceable failure, a note pointing at the payload op is attached.
485 DiagnosedSilenceableFailure transform::MatchStructuredInputOp::getPositionsFor(
486  linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) {
488  getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
489  op.getNumDpsInputs(), positions);
490  if (diag.isSilenceableFailure()) {
491  diag.attachNote(op->getLoc())
492  << "while considering DPS inputs of this payload operation";
493  }
494  return diag;
495 }
496 
497 /// Verifies a matcher op for structured input or output, specifically the
498 /// attributes specifying the operand positions.
499 template <typename OpTy>
501  if (op.getPermutation() && op.getProjectedPermutation()) {
502  return op.emitOpError()
503  << op.getPermutationAttrName() << " and "
504  << op.getProjectedPermutationAttrName() << " are mutually exclusive";
505  }
// A captured result is only allowed when at most one raw position is listed.
506  if (op.getRawPositionList().size() > 1 && op.getResult()) {
507  return op.emitOpError()
508  << "cannot bind multiple inputs/inits to the same value";
509  }
510 
511  return success();
512 }
513 
// Verifier body: shared operand-matcher checks, then position-list checks.
515  if (failed(verifyStructuredOperandOp(*this)))
516  return failure();
517  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
518  getIsInverted(), getIsAll());
519 }
520 
521 //===----------------------------------------------------------------------===//
522 // MatchStructuredInitOp
523 //===----------------------------------------------------------------------===//
524 
525 DiagnosedSilenceableFailure transform::MatchStructuredInitOp::matchOperation(
526  Operation *current, transform::TransformResults &results,
527  transform::TransformState &state) {
528  auto linalgOp = cast<linalg::LinalgOp>(current);
529  SmallVector<int64_t> positions;
530  DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions);
531  if (!diag.succeeded())
532  return diag;
533 
534  SmallVector<MappedValue> operandMapping;
535  operandMapping.reserve(positions.size());
536  for (int64_t position : positions) {
537  AffineMap indexingMap =
538  linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(position));
539  if (getPermutation() && !indexingMap.isPermutation()) {
540  return emitSilenceableError() << "the indexing map for output(init) #"
541  << position << " is not a permutation";
542  }
543  if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) {
544  return emitSilenceableError() << "the indexing map for output(init) #"
545  << position << " is not a permutation";
546  }
547 
548  // If capture not requested, skip it.
549  if (!getResult())
550  continue;
551 
552  if (isa<AffineMapParamType>(getResult().getType())) {
553  operandMapping.emplace_back(AffineMapAttr::get(indexingMap));
554  continue;
555  }
556 
557  Value operand = linalgOp.getDpsInitOperand(position)->get();
558  if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
559  operandMapping.emplace_back(operand);
560  continue;
561  }
562 
563  Operation *operandProducer = operand.getDefiningOp();
564  if (!operandProducer) {
565  return emitSilenceableError() << "output(init) #" << position
566  << " is not produced by an operation";
567  }
568  operandMapping.emplace_back(operandProducer);
569  }
570  if (getResult())
571  results.setMappedValues(cast<OpResult>(getResult()), operandMapping);
573 }
574 
// Expands the raw position list (honoring the "all" and "inverted" flags)
// against the payload op's number of DPS inits into `positions`; on
// silenceable failure, a note pointing at the payload op is attached.
575 DiagnosedSilenceableFailure transform::MatchStructuredInitOp::getPositionsFor(
576  linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) {
578  getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
579  op.getNumDpsInits(), positions);
580  if (diag.isSilenceableFailure()) {
581  diag.attachNote(op->getLoc())
582  << "while considering DPS inits (outputs) of this payload operation";
583  }
584  return diag;
585 }
586 
// Verifier body: shared operand-matcher checks, then position-list checks.
588  if (failed(verifyStructuredOperandOp(*this)))
589  return failure();
590  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
591  getIsInverted(), getIsAll());
592 }
593 
594 //===----------------------------------------------------------------------===//
595 // MatchStructuredNumInputsOp
596 //===----------------------------------------------------------------------===//
597 
// Captures the number of DPS inputs of the payload LinalgOp as an i64 param.
599 transform::MatchStructuredNumInputsOp::matchOperation(
600  Operation *current, transform::TransformResults &results,
601  transform::TransformState &state) {
602  auto linalgOp = cast<linalg::LinalgOp>(current);
603  Attribute attr =
604  Builder(current).getI64IntegerAttr(linalgOp.getNumDpsInputs());
605  results.setParams(cast<OpResult>(getResult()), {attr});
607 }
608 
609 //===----------------------------------------------------------------------===//
610 // MatchStructuredNumInitsOp
611 //===----------------------------------------------------------------------===//
612 
// Captures the number of DPS inits of the payload LinalgOp as an i64 param.
614 transform::MatchStructuredNumInitsOp::matchOperation(
615  Operation *current, transform::TransformResults &results,
616  transform::TransformState &state) {
617  auto linalgOp = cast<linalg::LinalgOp>(current);
618  Attribute attr =
619  Builder(current).getI64IntegerAttr(linalgOp.getNumDpsInits());
620  results.setParams(cast<OpResult>(getResult()), {attr});
622 }
623 
624 //===----------------------------------------------------------------------===//
625 // MatchStructuredRankOp
626 //===----------------------------------------------------------------------===//
627 
// Captures the number of loops of the payload LinalgOp as an i64 `rank` param.
628 DiagnosedSilenceableFailure transform::MatchStructuredRankOp::matchOperation(
629  Operation *current, transform::TransformResults &results,
630  transform::TransformState &state) {
631  auto linalgOp = cast<linalg::LinalgOp>(current);
632  int64_t numLoops = linalgOp.getNumLoops();
633  Attribute attr = Builder(linalgOp->getContext()).getI64IntegerAttr(numLoops);
634  results.setParams(cast<OpResult>(getRank()), {attr});
636 }
637 
638 //===----------------------------------------------------------------------===//
639 // MatchStructuredResultOp
640 //===----------------------------------------------------------------------===//
641 
642 DiagnosedSilenceableFailure transform::MatchStructuredResultOp::matchOperation(
644  transform::TransformState &state) {
645  auto linalgOp = cast<linalg::LinalgOp>(op);
646  int64_t position;
647  DiagnosedSilenceableFailure diag = getPositionFor(linalgOp, position);
648  if (!diag.succeeded())
649  return diag;
650 
651  Value result = linalgOp.getTiedOpResult(linalgOp.getDpsInitOperand(position));
652  if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
653  results.setValues(cast<OpResult>(getResult()), {result});
655  }
656 
657  if (result.getUsers().empty()) {
658  return emitSilenceableError()
659  << "no users of the result #" << getPosition();
660  }
661  Operation *firstUser = *result.getUsers().begin();
662  if (getAny()) {
663  results.set(cast<OpResult>(getResult()), {firstUser});
665  }
666  if (getSingle()) {
667  if (!llvm::hasSingleElement(result.getUsers())) {
668  return emitSilenceableError()
669  << "more than one result user with single user requested";
670  }
671  results.set(cast<OpResult>(getResult()), {firstUser});
673  }
674 
675  return emitDefiniteFailure() << "unknown sub-predicate";
676 }
677 
679 transform::MatchStructuredResultOp::getPositionFor(linalg::LinalgOp op,
680  int64_t &position) {
681  auto rawPosition = static_cast<int64_t>(getPosition());
682  position = rawPosition < 0 ? op.getNumDpsInits() + rawPosition : rawPosition;
683  if (position >= op.getNumDpsInits() || position < 0) {
684  return emitSilenceableError()
685  << "position " << rawPosition
686  << " overflows the number of results(ints) of the payload operation";
687  }
689 }
690 
// Verifier body: `any`/`single` must be present exactly when the result is an
// op handle (TransformHandleTypeInterface), hence the XOR; the two keywords
// are mutually exclusive.
692  if ((getAny() || getSingle()) ^
693  isa<TransformHandleTypeInterface>(getResult().getType())) {
694  return emitOpError() << "expects either the any/single keyword or the type "
695  "value handle result type";
696  }
697  if (getAny() && getSingle()) {
698  return emitOpError() << "'any' and 'single' are mutually exclusive";
699  }
700  return success();
701 }
702 
703 //===----------------------------------------------------------------------===//
704 // MatchStructuredYieldOp
705 //===----------------------------------------------------------------------===//
706 
// Memory effects: the yield only reads the forwarded handles and the payload.
707 void transform::MatchStructuredYieldOp::getEffects(
709  onlyReadsHandle(getHandles(), effects);
710  onlyReadsPayload(effects);
711 }
712 
713 void transform::MatchStructuredYieldOp::build(OpBuilder &builder,
714  OperationState &state) {
715  build(builder, state, ValueRange());
716 }
717 
718 #define GET_OP_CLASSES
719 #include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"
static DiagnosedSilenceableFailure containsAll(ArrayRef< unsigned > reference, ArrayRef< int64_t > list, Location loc, const char *message)
Checks if all values from list are also contained in reference.
LogicalResult verifyStructuredOperandOp(OpTy op)
Verifies a matcher op for structured input or output, specifically the attributes specifying the oper...
#define DBGS()
static std::string diag(const llvm::Value &value)
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:47
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e. a projection) of a symbol-less permutation map.
Definition: AffineMap.cpp:581
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Definition: AffineMap.cpp:611
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgument getArgument(unsigned i)
Definition: Block.h:126
unsigned getNumArguments()
Definition: Block.h:125
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
Operation & front()
Definition: Block.h:150
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:128
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:209
This class represents an operand of an operation.
Definition: Value.h:263
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation is before 'other' in the operation list of the parent block.
Definition: Operation.cpp:386
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:669
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition: Operation.h:234
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:117
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:123
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:378
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition: Value.h:128
Type getType() const
Return the type of this value.
Definition: Value.h:125
user_range getUsers() const
Definition: Value.h:224
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void setRemainingToEmpty(TransformOpInterface transform)
Sets the currently unset results to empty lists of the kind expected by the corresponding results of ...
void setMappedValues(OpResult handle, ArrayRef< MappedValue > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
The state maintained across applications of various ops implementing the TransformOpInterface.
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a ...
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
Definition: Utils.cpp:169
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
uint64_t getN(LevelType lt)
Definition: Enums.h:434
uint64_t getM(LevelType lt)
Definition: Enums.h:435
void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
Populates results with payload associations that match exactly those of the operands to block's termi...
LogicalResult verifyStructuredOpPredicateOpTrait(Operation *op, Value structuredOpHandle)
void prepareValueMappings(SmallVectorImpl< SmallVector< transform::MappedValue >> &mappings, ValueRange values, const transform::TransformState &state)
Populates mappings with mapped values associated with the given transform IR values in the given stat...
LogicalResult verifyTransformMatchDimsOp(Operation *op, ArrayRef< int64_t > raw, bool inverted, bool all)
Checks if the positional specification defined is valid and reports errors otherwise.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
DiagnosedSilenceableFailure expandTargetSpecification(Location loc, bool isAll, bool isInverted, ArrayRef< int64_t > rawList, int64_t maxNumber, SmallVectorImpl< int64_t > &result)
Populates result with the positional identifiers relative to maxNumber.
void onlyReadsHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void producesHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
llvm::PointerUnion< Operation *, Param, Value > MappedValue
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This represents an operation in an abstracted form, suitable for use with the builder APIs.