MLIR  22.0.0git
LinalgMatchOps.cpp
Go to the documentation of this file.
1 //===- LinalgMatchOps.cpp - Implementation of Linalg match ops ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
18 #include "llvm/Support/DebugLog.h"
19 #include "llvm/Support/FormatVariadic.h"
20 #include "llvm/Support/InterleavedRange.h"
21 
22 using namespace mlir;
23 
24 #define DEBUG_TYPE "linalg-transforms"
25 
26 //===----------------------------------------------------------------------===//
27 // StructuredMatchOp
28 //===----------------------------------------------------------------------===//
29 
30 DiagnosedSilenceableFailure transform::MatchStructuredOp::matchOperation(
31  Operation *current, transform::TransformResults &results,
33  // First, check if the payload operation is a structured Linalg operation.
34  if (!isa<linalg::LinalgOp>(current)) {
35  if (getFailurePropagationMode().value_or(
36  FailurePropagationMode::Propagate) ==
37  FailurePropagationMode::Propagate) {
38  return emitSilenceableError() << "expected a Linalg op";
39  }
40  // If errors are suppressed, succeed and set all results to empty lists.
41  LDBG() << "optional nested matcher expected a Linalg op";
42  results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation()));
44  }
45 
46  // Bind `current` to the block argument.
47  auto scope = state.make_region_scope(getBodyRegion());
48  if (failed(state.mapBlockArgument(getBody()->getArgument(0),
49  MappedValue(current)))) {
51  }
52 
53  for (Operation &nested : getBody()->without_terminator()) {
55  state.applyTransform(cast<TransformOpInterface>(nested));
56  if (diag.isDefiniteFailure())
57  return diag;
58  if (diag.succeeded())
59  continue;
60 
61  // If propagating errors, do this immediately.
62  assert(diag.isSilenceableFailure());
63  if (getFailurePropagationMode().value_or(
64  FailurePropagationMode::Propagate) ==
65  FailurePropagationMode::Propagate) {
66  return diag;
67  }
68 
69  // If suppressing errors, print the message into the debug stream before
70  // silencing it. Then set all results value that are already known.
71  // Results come from the terminator operands, which may be defined in the
72  // (single) block of this operation or above it. When they are defined
73  // above, they are known to be mapped at this point per SSA dominance.
74  // When they are defined in this block, we additionally check if we have
75  // already applied the operation that defines them. If not, the
76  // corresponding results will be set to empty lists.
77  LDBG() << "optional nested matcher failed: " << diag.getMessage();
78  (void)diag.silence();
79  SmallVector<OpOperand *> undefinedOperands;
80  for (OpOperand &terminatorOperand :
81  getBody()->getTerminator()->getOpOperands()) {
82  Operation *definingOp = terminatorOperand.get().getDefiningOp();
83  if (!definingOp)
84  continue;
85  if (definingOp->getBlock() != getBody())
86  continue;
87  if (definingOp->isBeforeInBlock(&nested))
88  continue;
89 
90  undefinedOperands.push_back(&terminatorOperand);
91  }
92 
94  auto filtered = llvm::make_filter_range(
95  getBody()->getTerminator()->getOpOperands(), [&](OpOperand &opOperand) {
96  return !llvm::is_contained(undefinedOperands, &opOperand);
97  });
98  SmallVector<Value> definedOperands = llvm::to_vector(llvm::map_range(
99  filtered, [](OpOperand &opOperand) { return opOperand.get(); }));
100  detail::prepareValueMappings(mappings, definedOperands, state);
101  for (auto &&[operand, mapping] : llvm::zip_equal(filtered, mappings)) {
102  results.setMappedValues(getResults()[operand.getOperandNumber()],
103  mapping);
104  }
105  results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation()));
107  }
108 
109  // Set the results.
110  detail::forwardTerminatorOperands(getBody(), state, results);
112 }
113 
114 void transform::MatchStructuredOp::getEffects(
116  onlyReadsHandle(getCurrentMutable(), effects);
117  onlyReadsPayload(effects);
118  producesHandle(getOperation()->getOpResults(), effects);
119 }
120 
121 LogicalResult transform::MatchStructuredOp::verify() {
122  if (getBody()->getNumArguments() != 1)
123  return emitOpError() << "expected one body argument";
124  if (!isa<TransformHandleTypeInterface>(getBody()->getArgument(0).getType())) {
125  return emitOpError() << "expected body argument to implement "
126  "TransformHandleTypeInterface";
127  }
128  for (Operation &nested : getBody()->without_terminator()) {
129  if (isa<MatchOpInterface>(nested))
130  continue;
132  emitOpError()
133  << "expects nested operations to implement MatchOpInterface";
134  diag.attachNote(nested.getLoc()) << "offending operation";
135  return diag;
136  }
137  return success();
138 }
139 
140 //===----------------------------------------------------------------------===//
141 // StructuredOpPredicateOpTrait
142 //===----------------------------------------------------------------------===//
143 
145  Operation *op, Value structuredOpHandle) {
146  if (!isa_and_nonnull<MatchStructuredOp>(op->getParentOp())) {
147  return op->emitOpError() << "expects parent op to be '"
148  << MatchStructuredOp::getOperationName() << "'";
149  }
150 
151  // Bail out here, let the verifier of the parent complain.
152  Operation *parent = op->getParentOp();
153  if (parent->getNumRegions() < 1 || parent->getRegion(0).empty() ||
154  parent->getRegion(0).front().getNumArguments() < 1)
155  return success();
156 
157  if (structuredOpHandle != parent->getRegion(0).front().getArgument(0)) {
158  return op->emitOpError()
159  << "expected predicate to apply to the surrounding structured op";
160  }
161  return success();
162 }
163 
164 //===----------------------------------------------------------------------===//
165 // MatchStructuredBodyOp
166 //===----------------------------------------------------------------------===//
167 
168 DiagnosedSilenceableFailure transform::MatchStructuredBodyOp::matchOperation(
169  Operation *current, transform::TransformResults &results,
170  transform::TransformState &state) {
171  auto linalgOp = cast<linalg::LinalgOp>(current);
172  if (std::optional<uint64_t> position = getReductionPosition()) {
173  SmallVector<Operation *> combinerOps;
174  if (!matchReduction(linalgOp.getRegionOutputArgs(), *position,
175  combinerOps)) {
176  return emitSilenceableError() << "could not match reduction";
177  }
178  if (combinerOps.size() != 1) {
179  return emitSilenceableError() << "reduction combiner is not a single op";
180  }
182  }
183  if (getPassthrough()) {
184  Block &body = linalgOp->getRegion(0).front();
185  if (body.getTerminator()->getOperands() != linalgOp.getRegionInputArgs()) {
186  return emitSilenceableError() << "not a passthrough";
187  }
189  }
190  if (getElementwise()) {
191  if (!isElementwise(linalgOp))
192  return emitSilenceableError() << "not elementwise";
194  }
195  if (std::optional<ArrayAttr> contractionOps = getContraction()) {
196  Block &body = linalgOp->getRegion(0).front();
197  std::string message;
198  llvm::raw_string_ostream os(message);
199  bool result = linalg::detail::isContractionBody(
200  body,
201  [&](Operation *elem, Operation *red) {
202  return elem->getName().getStringRef() ==
203  cast<StringAttr>((*contractionOps)[0]).getValue() &&
204  red->getName().getStringRef() ==
205  cast<StringAttr>((*contractionOps)[1]).getValue();
206  },
207  os);
208  if (result)
210  return emitSilenceableError() << "contraction: " << message;
211  }
212  return emitDefiniteFailure() << "unknown body condition";
213 }
214 
216  int64_t numOptions = getReductionPosition().has_value() + getPassthrough() +
217  getElementwise() + getContraction().has_value();
218 
219  if (numOptions > 1) {
220  StringAttr attributeNames[] = {
221  getReductionPositionAttrName(), getPassthroughAttrName(),
222  getElementwiseAttrName(), getContractionAttrName()};
223  return emitOpError() << "only one of {" << llvm::interleaved(attributeNames)
224  << "} is allowed";
225  }
226 
227  if (std::optional<ArrayAttr> contractionAttr = getContraction()) {
228  if (contractionAttr->size() != 2) {
229  return emitOpError() << "expects " << getContractionAttrName()
230  << " to contain two elements";
231  }
232  }
233  return success();
234 }
235 
236 //===----------------------------------------------------------------------===//
237 // MatchStructuredClassifyContractionDimsOp
238 //===----------------------------------------------------------------------===//
239 
241 transform::MatchStructuredClassifyContractionDimsOp::matchOperation(
242  Operation *current, transform::TransformResults &results,
243  transform::TransformState &state) {
244  FailureOr<linalg::ContractionDimensions> contractionDims =
245  linalg::inferContractionDims(cast<linalg::LinalgOp>(current));
246  if (failed(contractionDims))
247  return emitSilenceableError() << "could not infer contraction dimensions";
248 
249  MLIRContext *context = current->getContext();
250  Builder builder(context);
251  auto makeI64Attrs = [&](ArrayRef<unsigned> values) {
252  return llvm::to_vector(
253  llvm::map_range(values, [&](unsigned value) -> Attribute {
254  return builder.getI64IntegerAttr(value);
255  }));
256  };
257  results.setParams(cast<OpResult>(getBatch()),
258  makeI64Attrs(contractionDims->batch));
259  results.setParams(cast<OpResult>(getM()), makeI64Attrs(contractionDims->m));
260  results.setParams(cast<OpResult>(getN()), makeI64Attrs(contractionDims->n));
261  results.setParams(cast<OpResult>(getK()), makeI64Attrs(contractionDims->k));
263 }
264 
265 //===----------------------------------------------------------------------===//
266 // MatchStructuredClassifyConvolutionDimsOp
267 //===----------------------------------------------------------------------===//
268 
270 transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation(
271  Operation *current, transform::TransformResults &results,
272  transform::TransformState &state) {
273  FailureOr<linalg::ConvolutionDimensions> convolutionDims =
274  linalg::inferConvolutionDims(cast<linalg::LinalgOp>(current));
275  if (failed(convolutionDims))
276  return emitSilenceableError() << "could not infer convolution dimensions";
277 
278  MLIRContext *context = current->getContext();
279  Builder builder(context);
280  auto makeI64Attrs = [&](ArrayRef<unsigned> values) {
281  return llvm::to_vector(
282  llvm::map_range(values, [&](unsigned value) -> Attribute {
283  return builder.getI64IntegerAttr(value);
284  }));
285  };
286  results.setParams(cast<OpResult>(getBatch()),
287  makeI64Attrs(convolutionDims->batch));
288  results.setParams(cast<OpResult>(getOutputImage()),
289  makeI64Attrs(convolutionDims->outputImage));
290  results.setParams(cast<OpResult>(getOutputChannel()),
291  makeI64Attrs(convolutionDims->outputChannel));
292  results.setParams(cast<OpResult>(getFilterLoop()),
293  makeI64Attrs(convolutionDims->filterLoop));
294  results.setParams(cast<OpResult>(getInputChannel()),
295  makeI64Attrs(convolutionDims->inputChannel));
296  results.setParams(cast<OpResult>(getDepth()),
297  makeI64Attrs(convolutionDims->depth));
298 
299  auto makeI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
300  return llvm::to_vector(
301  llvm::map_range(values, [&](int64_t value) -> Attribute {
302  return builder.getI64IntegerAttr(value);
303  }));
304  };
305  results.setParams(cast<OpResult>(getStrides()),
306  makeI64AttrsFromI64(convolutionDims->strides));
307  results.setParams(cast<OpResult>(getDilations()),
308  makeI64AttrsFromI64(convolutionDims->dilations));
310 }
311 
312 //===----------------------------------------------------------------------===//
313 // Utilities for structured match predicates.
314 //===----------------------------------------------------------------------===//
315 
316 /// Checks if all values from `list` are also contained in `reference`. Returns
317 /// a silenceable error with the given message at the given location when it is
318 /// not the case. The error message must contain the "{0}" placeholder that
319 /// will be substituted with the value from `list` that is not contained in
320 /// `reference`.
322  ArrayRef<int64_t> list,
323  Location loc,
324  const char *message) {
325  for (int64_t value : list) {
326  if (llvm::any_of(reference, [&](unsigned ref) {
327  return static_cast<int64_t>(ref) == value;
328  })) {
329  continue;
330  }
331  return emitSilenceableFailure(loc) << llvm::formatv(message, value);
332  }
334 }
335 
336 //===----------------------------------------------------------------------===//
337 // MatchStructuredDimOp
338 //===----------------------------------------------------------------------===//
339 
340 DiagnosedSilenceableFailure transform::MatchStructuredDimOp::matchOperation(
341  Operation *current, transform::TransformResults &results,
342  transform::TransformState &state) {
343  auto linalgOp = cast<linalg::LinalgOp>(current);
344  SmallVector<int64_t> dimensions;
345  DiagnosedSilenceableFailure diag = getDimensionsFor(linalgOp, dimensions);
346  if (!diag.succeeded())
347  return diag;
348 
349  // If asked to check for the kind of dimension, perform the check.
350  if (getParallel() || getReduction()) {
351  SmallVector<unsigned> reference;
352  if (getParallel())
353  linalgOp.getParallelDims(reference);
354  else if (getReduction())
355  linalgOp.getReductionDims(reference);
356 
358  containsAll(reference, dimensions, getLoc(),
359  getParallel() ? "expects dimension #{0} to be parallel"
360  : "expects dimension #{0} to be reduction");
361  if (!diag.succeeded())
362  return diag;
363  }
364 
365  // If not capturing, we are done here.
366  if (!getResult())
367  return diag;
368 
369  SmallVector<int64_t, 4> ranges = linalgOp.getStaticLoopRanges();
370  Builder builder(current);
371  SmallVector<Attribute> captured = llvm::to_vector(
372  llvm::map_range(dimensions, [&](int64_t dim) -> Attribute {
373  return builder.getI64IntegerAttr(ranges[dim]);
374  }));
375  results.setParams(cast<OpResult>(getResult()), captured);
377 }
378 
379 DiagnosedSilenceableFailure transform::MatchStructuredDimOp::getDimensionsFor(
380  linalg::LinalgOp op, SmallVectorImpl<int64_t> &dims) {
382  expandTargetSpecification(getLoc(), getIsAll(), getIsInverted(),
383  getRawDimList(), op.getNumLoops(), dims);
384  if (diag.isSilenceableFailure()) {
385  diag.attachNote(op->getLoc())
386  << "while considering dimensions of this payload operation";
387  }
388  return diag;
389 }
390 
392  if (getParallel() && getReduction()) {
393  return emitOpError() << "cannot request the same dimension to be both "
394  "parallel and reduction";
395  }
396  return verifyTransformMatchDimsOp(getOperation(), getRawDimList(),
397  getIsInverted(), getIsAll());
398 }
399 
400 //===----------------------------------------------------------------------===//
401 // MatchStructuredElementalBitwidthOp
402 //===----------------------------------------------------------------------===//
403 
405 transform::MatchStructuredElementalBitwidthOp::matchValue(
406  Value current, transform::TransformResults &results,
407  transform::TransformState &state) {
408  auto setupResult = [&](int64_t bitwidth) {
409  Attribute attr = Builder(current.getContext()).getI64IntegerAttr(bitwidth);
410  results.setParams(cast<OpResult>(getResult()), {attr});
412  };
413 
414  Type type = current.getType();
415  if (type.isIntOrFloat())
416  return setupResult(type.getIntOrFloatBitWidth());
417 
418  if (auto shapedType = dyn_cast<ShapedType>(type)) {
419  if (shapedType.getElementType().isIntOrFloat())
420  return setupResult(shapedType.getElementTypeBitWidth());
421  }
422  return emitSilenceableError()
423  << "unsupported type for bitwidth extraction: " << type;
424 }
425 
426 //===----------------------------------------------------------------------===//
427 // MatchStructuredInputOp
428 //===----------------------------------------------------------------------===//
429 
430 DiagnosedSilenceableFailure transform::MatchStructuredInputOp::matchOperation(
431  Operation *current, transform::TransformResults &results,
432  transform::TransformState &state) {
433  auto linalgOp = cast<linalg::LinalgOp>(current);
434  SmallVector<int64_t> positions;
435  DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions);
436  if (!diag.succeeded())
437  return diag;
438 
439  SmallVector<MappedValue> operandMapping;
440  operandMapping.reserve(positions.size());
441  for (int64_t position : positions) {
442  AffineMap indexingMap =
443  linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(position));
444  if (getPermutation() && !indexingMap.isPermutation()) {
445  return emitSilenceableError() << "the indexing map for input #"
446  << position << " is not a permutation";
447  }
448  if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) {
449  return emitSilenceableError()
450  << "the indexing map for input #" << position
451  << " is not a projected permutation";
452  }
453 
454  // If capture not requested, skip it.
455  if (!getResult())
456  continue;
457 
458  if (isa<AffineMapParamType>(getResult().getType())) {
459  operandMapping.emplace_back(AffineMapAttr::get(indexingMap));
460  continue;
461  }
462 
463  Value operand = linalgOp.getDpsInputOperand(position)->get();
464  if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
465  operandMapping.emplace_back(operand);
466  continue;
467  }
468 
469  Operation *operandProducer = operand.getDefiningOp();
470  if (!operandProducer) {
471  return emitSilenceableError()
472  << "input #" << position << " is not produced by an operation";
473  }
474  operandMapping.emplace_back(operandProducer);
475  }
476  if (getResult())
477  results.setMappedValues(cast<OpResult>(getResult()), operandMapping);
479 }
480 
481 DiagnosedSilenceableFailure transform::MatchStructuredInputOp::getPositionsFor(
482  linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) {
484  getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
485  op.getNumDpsInputs(), positions);
486  if (diag.isSilenceableFailure()) {
487  diag.attachNote(op->getLoc())
488  << "while considering DPS inputs of this payload operation";
489  }
490  return diag;
491 }
492 
493 /// Verifies a matcher op for structured input or output, specifically the
494 /// attributes specifying the operand positions.
495 template <typename OpTy>
496 LogicalResult verifyStructuredOperandOp(OpTy op) {
497  if (op.getPermutation() && op.getProjectedPermutation()) {
498  return op.emitOpError()
499  << op.getPermutationAttrName() << " and "
500  << op.getProjectedPermutationAttrName() << " are mutually exclusive";
501  }
502  if (op.getRawPositionList().size() > 1 && op.getResult()) {
503  return op.emitOpError()
504  << "cannot bind multiple inputs/inits to the same value";
505  }
506 
507  return success();
508 }
509 
511  if (failed(verifyStructuredOperandOp(*this)))
512  return failure();
513  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
514  getIsInverted(), getIsAll());
515 }
516 
517 //===----------------------------------------------------------------------===//
518 // MatchStructuredInitOp
519 //===----------------------------------------------------------------------===//
520 
521 DiagnosedSilenceableFailure transform::MatchStructuredInitOp::matchOperation(
522  Operation *current, transform::TransformResults &results,
523  transform::TransformState &state) {
524  auto linalgOp = cast<linalg::LinalgOp>(current);
525  SmallVector<int64_t> positions;
526  DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions);
527  if (!diag.succeeded())
528  return diag;
529 
530  SmallVector<MappedValue> operandMapping;
531  operandMapping.reserve(positions.size());
532  for (int64_t position : positions) {
533  AffineMap indexingMap =
534  linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(position));
535  if (getPermutation() && !indexingMap.isPermutation()) {
536  return emitSilenceableError() << "the indexing map for output(init) #"
537  << position << " is not a permutation";
538  }
539  if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) {
540  return emitSilenceableError() << "the indexing map for output(init) #"
541  << position << " is not a permutation";
542  }
543 
544  // If capture not requested, skip it.
545  if (!getResult())
546  continue;
547 
548  if (isa<AffineMapParamType>(getResult().getType())) {
549  operandMapping.emplace_back(AffineMapAttr::get(indexingMap));
550  continue;
551  }
552 
553  Value operand = linalgOp.getDpsInitOperand(position)->get();
554  if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
555  operandMapping.emplace_back(operand);
556  continue;
557  }
558 
559  Operation *operandProducer = operand.getDefiningOp();
560  if (!operandProducer) {
561  return emitSilenceableError() << "output(init) #" << position
562  << " is not produced by an operation";
563  }
564  operandMapping.emplace_back(operandProducer);
565  }
566  if (getResult())
567  results.setMappedValues(cast<OpResult>(getResult()), operandMapping);
569 }
570 
571 DiagnosedSilenceableFailure transform::MatchStructuredInitOp::getPositionsFor(
572  linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) {
574  getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
575  op.getNumDpsInits(), positions);
576  if (diag.isSilenceableFailure()) {
577  diag.attachNote(op->getLoc())
578  << "while considering DPS inits (outputs) of this payload operation";
579  }
580  return diag;
581 }
582 
584  if (failed(verifyStructuredOperandOp(*this)))
585  return failure();
586  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
587  getIsInverted(), getIsAll());
588 }
589 
590 //===----------------------------------------------------------------------===//
591 // MatchStructuredNumInputsOp
592 //===----------------------------------------------------------------------===//
593 
595 transform::MatchStructuredNumInputsOp::matchOperation(
596  Operation *current, transform::TransformResults &results,
597  transform::TransformState &state) {
598  auto linalgOp = cast<linalg::LinalgOp>(current);
599  Attribute attr =
600  Builder(current).getI64IntegerAttr(linalgOp.getNumDpsInputs());
601  results.setParams(cast<OpResult>(getResult()), {attr});
603 }
604 
605 //===----------------------------------------------------------------------===//
606 // MatchStructuredNumInitsOp
607 //===----------------------------------------------------------------------===//
608 
610 transform::MatchStructuredNumInitsOp::matchOperation(
611  Operation *current, transform::TransformResults &results,
612  transform::TransformState &state) {
613  auto linalgOp = cast<linalg::LinalgOp>(current);
614  Attribute attr =
615  Builder(current).getI64IntegerAttr(linalgOp.getNumDpsInits());
616  results.setParams(cast<OpResult>(getResult()), {attr});
618 }
619 
620 //===----------------------------------------------------------------------===//
621 // MatchStructuredRankOp
622 //===----------------------------------------------------------------------===//
623 
624 DiagnosedSilenceableFailure transform::MatchStructuredRankOp::matchOperation(
625  Operation *current, transform::TransformResults &results,
626  transform::TransformState &state) {
627  auto linalgOp = cast<linalg::LinalgOp>(current);
628  int64_t numLoops = linalgOp.getNumLoops();
629  Attribute attr = Builder(linalgOp->getContext()).getI64IntegerAttr(numLoops);
630  results.setParams(cast<OpResult>(getRank()), {attr});
632 }
633 
634 //===----------------------------------------------------------------------===//
635 // MatchStructuredResultOp
636 //===----------------------------------------------------------------------===//
637 
638 DiagnosedSilenceableFailure transform::MatchStructuredResultOp::matchOperation(
640  transform::TransformState &state) {
641  auto linalgOp = cast<linalg::LinalgOp>(op);
642  int64_t position;
643  DiagnosedSilenceableFailure diag = getPositionFor(linalgOp, position);
644  if (!diag.succeeded())
645  return diag;
646 
647  Value result = linalgOp.getTiedOpResult(linalgOp.getDpsInitOperand(position));
648  if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
649  results.setValues(cast<OpResult>(getResult()), {result});
651  }
652 
653  if (result.getUsers().empty()) {
654  return emitSilenceableError()
655  << "no users of the result #" << getPosition();
656  }
657  Operation *firstUser = *result.getUsers().begin();
658  if (getAny()) {
659  results.set(cast<OpResult>(getResult()), {firstUser});
661  }
662  if (getSingle()) {
663  if (!llvm::hasSingleElement(result.getUsers())) {
664  return emitSilenceableError()
665  << "more than one result user with single user requested";
666  }
667  results.set(cast<OpResult>(getResult()), {firstUser});
669  }
670 
671  return emitDefiniteFailure() << "unknown sub-predicate";
672 }
673 
675 transform::MatchStructuredResultOp::getPositionFor(linalg::LinalgOp op,
676  int64_t &position) {
677  auto rawPosition = static_cast<int64_t>(getPosition());
678  position = rawPosition < 0 ? op.getNumDpsInits() + rawPosition : rawPosition;
679  if (position >= op.getNumDpsInits() || position < 0) {
680  return emitSilenceableError()
681  << "position " << rawPosition
682  << " overflows the number of results(ints) of the payload operation";
683  }
685 }
686 
688  if ((getAny() || getSingle()) ^
689  isa<TransformHandleTypeInterface>(getResult().getType())) {
690  return emitOpError() << "expects either the any/single keyword or the type "
691  "value handle result type";
692  }
693  if (getAny() && getSingle()) {
694  return emitOpError() << "'any' and 'single' are mutually exclusive";
695  }
696  return success();
697 }
698 
699 //===----------------------------------------------------------------------===//
700 // MatchStructuredYieldOp
701 //===----------------------------------------------------------------------===//
702 
703 void transform::MatchStructuredYieldOp::getEffects(
705  onlyReadsHandle(getHandlesMutable(), effects);
706  onlyReadsPayload(effects);
707 }
708 
709 void transform::MatchStructuredYieldOp::build(OpBuilder &builder,
710  OperationState &state) {
711  build(builder, state, ValueRange());
712 }
713 
714 #define GET_OP_CLASSES
715 #include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"
static DiagnosedSilenceableFailure containsAll(ArrayRef< unsigned > reference, ArrayRef< int64_t > list, Location loc, const char *message)
Checks if all values from list are also contained in reference.
LogicalResult verifyStructuredOperandOp(OpTy op)
Verifies a matcher op for structured input or output, specifically the attributes specifying the oper...
static std::string diag(const llvm::Value &value)
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e. a projection) of a symbol-less permutation map.
Definition: AffineMap.cpp:611
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Definition: AffineMap.cpp:641
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
Operation & front()
Definition: Block.h:153
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:107
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:205
This class represents an operand of an operation.
Definition: Value.h:257
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:385
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:116
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition: Value.h:108
Type getType() const
Return the type of this value.
Definition: Value.h:105
user_range getUsers() const
Definition: Value.h:218
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void setRemainingToEmpty(TransformOpInterface transform)
Sets the currently unset results to empty lists of the kind expected by the corresponding results of ...
void setMappedValues(OpResult handle, ArrayRef< MappedValue > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
The state maintained across applications of various ops implementing the TransformOpInterface.
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a ...
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
Definition: Utils.cpp:220
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
uint64_t getN(LevelType lt)
Definition: Enums.h:442
uint64_t getM(LevelType lt)
Definition: Enums.h:443
void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
Populates results with payload associations that match exactly those of the operands to block's termi...
LogicalResult verifyStructuredOpPredicateOpTrait(Operation *op, Value structuredOpHandle)
void prepareValueMappings(SmallVectorImpl< SmallVector< transform::MappedValue >> &mappings, ValueRange values, const transform::TransformState &state)
Populates mappings with mapped values associated with the given transform IR values in the given stat...
LogicalResult verifyTransformMatchDimsOp(Operation *op, ArrayRef< int64_t > raw, bool inverted, bool all)
Checks if the positional specification defined is valid and reports errors otherwise.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
DiagnosedSilenceableFailure expandTargetSpecification(Location loc, bool isAll, bool isInverted, ArrayRef< int64_t > rawList, int64_t maxNumber, SmallVectorImpl< int64_t > &result)
Populates result with the positional identifiers relative to maxNumber.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
llvm::PointerUnion< Operation *, Param, Value > MappedValue
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
This represents an operation in an abstracted form, suitable for use with the builder APIs.