MLIR  19.0.0git
LinalgOps.cpp
Go to the documentation of this file.
1 //===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the Linalg operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
29 #include "mlir/IR/AffineMap.h"
32 #include "mlir/IR/Matchers.h"
35 #include "mlir/IR/PatternMatch.h"
37 
38 #include "llvm/ADT/DenseMap.h"
39 #include "llvm/ADT/SmallSet.h"
40 #include "llvm/ADT/StringSet.h"
41 #include "llvm/ADT/TypeSwitch.h"
42 #include "llvm/Support/FormatVariadic.h"
43 #include "llvm/Support/MathExtras.h"
44 #include "llvm/Support/raw_ostream.h"
45 #include <optional>
46 
47 using namespace mlir;
48 using namespace mlir::linalg;
49 
50 /// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
52  int64_t dim) {
53  auto type = cast<ShapedType>(v.getType());
54  if (!type.isDynamicDim(dim))
55  return builder.getIndexAttr(type.getDimSize(dim));
56 
57  return getAsOpFoldResult(
59  .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
60  return builder.create<tensor::DimOp>(loc, v, dim);
61  })
62  .Case<MemRefType>([&](MemRefType t) -> Value {
63  return builder.create<memref::DimOp>(loc, v, dim);
64  }));
65 }
66 
67 /// Returns a memref.subview or a tensor.extract_slice based on the type of the
68 /// `source`.
69 static Value getSlice(OpBuilder &b, Location loc, Value source,
70  ArrayRef<OpFoldResult> offsets,
72  ArrayRef<OpFoldResult> strides) {
73  return TypeSwitch<Type, Value>(source.getType())
74  .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
75  return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
76  strides);
77  })
78  .Case<MemRefType>([&](MemRefType type) -> Value {
79  return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
80  strides);
81  })
82  .Default([&](Type t) { return nullptr; });
83 }
84 
85 //===----------------------------------------------------------------------===//
86 // Helper functions
87 //===----------------------------------------------------------------------===//
88 
90  int64_t dim) {
91  if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
92  return b.createOrFold<memref::DimOp>(loc, source, dim);
93  if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
94  return b.createOrFold<tensor::DimOp>(loc, source, dim);
95  llvm_unreachable("Expected MemRefType or TensorType");
96 }
97 
99  int64_t dim) {
100  auto shapedType = llvm::cast<ShapedType>(source.getType());
101  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
102  return createOrFoldDimOp(b, loc, source, dim);
103  return b.getIndexAttr(shapedType.getDimSize(dim));
104 }
105 
106 //===----------------------------------------------------------------------===//
107 // Support for named Linalg ops defined in ods-gen.
108 //===----------------------------------------------------------------------===//
109 
112 
113 /// Fills the region of a structured operation using the provided
114 /// `regionBuilder`. The method is used by both named structured ops created by
115 /// ods-gen and by manually defined C++ ops. It is called by both builders and
116 /// parsers and creates a block with arguments corresponding to the elemental
117 /// types of `inputTypes` and `outputTypes`. All output types are asserted to be
118 /// ShapedType.
119 static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
120  TypeRange inputTypes, TypeRange outputTypes,
122  RegionBuilderFn regionBuilder) {
123  assert(llvm::all_of(outputTypes, llvm::IsaPred<ShapedType>));
124 
125  SmallVector<Type, 8> argTypes;
126  SmallVector<Location, 8> argLocs;
127  for (auto containers : {inputTypes, outputTypes}) {
128  for (auto t : containers) {
129  argTypes.push_back(
130  isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
131 
132  // TODO: Pass in a proper location here.
133  argLocs.push_back(opBuilder.getUnknownLoc());
134  }
135  }
136 
137  // RAII.
138  OpBuilder::InsertionGuard guard(opBuilder);
139  Block *body =
140  opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);
141 
142  opBuilder.setInsertionPointToStart(body);
143  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
144  regionBuilder(b, *body, attrs);
145 
146  // indexing_maps is an auto-generated method.
147 
148  // iterator_types is an auto-generated method.
149 }
150 
151 /// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
152 /// The result types are derived automatically if `resultTensorTypes` is none.
153 /// The body of the operation is filled using `regionBuilder`. All ods-gen
154 /// created structured operations use the method to implement their builders.
156  std::optional<TypeRange> resultTensorTypes,
157  ValueRange inputs, ValueRange outputs,
158  ArrayRef<NamedAttribute> attributes,
159  RegionBuilderFn regionBuilder) {
160  // Derive the result types if needed.
161  SmallVector<Type> derivedResultTypes =
162  resultTensorTypes.value_or(TypeRange());
163  if (!resultTensorTypes)
164  copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
165  llvm::IsaPred<RankedTensorType>);
166 
167  state.addOperands(inputs);
168  state.addOperands(outputs);
169  state.addTypes(derivedResultTypes);
170  state.addAttributes(attributes);
171  state.addAttribute(
172  "operandSegmentSizes",
173  b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
174  static_cast<int32_t>(outputs.size())}));
175 
176  // Create and fill the region of the structured operation.
177  Region &region = *state.addRegion();
178  fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
179  state.attributes.getAttrs(), regionBuilder);
180 }
181 
182 /// Common parsing used for both named structured ops created by ods-gen and by
183 /// manually defined C++ ops. Does not handle regions.
184 static ParseResult
186  SmallVectorImpl<Type> &inputTypes,
187  SmallVectorImpl<Type> &outputTypes,
188  bool addOperandSegmentSizes = true) {
189  SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
191  outputsOperands;
192 
193  if (succeeded(parser.parseOptionalLess())) {
194  if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
195  return failure();
196  }
197  attrsLoc = parser.getCurrentLocation();
198  if (parser.parseOptionalAttrDict(result.attributes))
199  return failure();
200 
201  if (succeeded(parser.parseOptionalKeyword("ins"))) {
202  if (parser.parseLParen())
203  return failure();
204 
205  inputsOperandsLoc = parser.getCurrentLocation();
206  if (parser.parseOperandList(inputsOperands) ||
207  parser.parseColonTypeList(inputTypes) || parser.parseRParen())
208  return failure();
209  }
210 
211  if (succeeded(parser.parseOptionalKeyword("outs"))) {
212  outputsOperandsLoc = parser.getCurrentLocation();
213  if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
214  parser.parseColonTypeList(outputTypes) || parser.parseRParen())
215  return failure();
216  }
217 
218  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
219  result.operands) ||
220  parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
221  result.operands))
222  return failure();
223 
224  if (addOperandSegmentSizes) {
225  // This is a bit complex because we're trying to be backward compatible with
226  // operation syntax that mix the inherent attributes and the discardable
227  // ones in the same dictionary. If the properties are used, we append the
228  // operandSegmentSizes there directly. Otherwise we append it to the
229  // discardable attributes dictionary where it is handled by the generic
230  // Operation::create(...) method.
231  if (result.propertiesAttr) {
232  NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
233  attrs.append("operandSegmentSizes",
235  {static_cast<int32_t>(inputsOperands.size()),
236  static_cast<int32_t>(outputsOperands.size())}));
237  result.propertiesAttr = attrs.getDictionary(parser.getContext());
238  } else {
239  result.addAttribute("operandSegmentSizes",
241  {static_cast<int32_t>(inputsOperands.size()),
242  static_cast<int32_t>(outputsOperands.size())}));
243  }
244  }
245  if (!result.propertiesAttr) {
246  std::optional<RegisteredOperationName> info =
247  result.name.getRegisteredInfo();
248  if (info) {
249  if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
250  return parser.emitError(attrsLoc)
251  << "'" << result.name.getStringRef() << "' op ";
252  })))
253  return failure();
254  }
255  }
256  return success();
257 }
258 
260  ValueRange outputs) {
261  if (!inputs.empty())
262  p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
263  if (!outputs.empty())
264  p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
265 }
266 
267 //===----------------------------------------------------------------------===//
268 // Specific parsing and printing for named structured ops created by ods-gen.
269 //===----------------------------------------------------------------------===//
270 
272  OpAsmParser &parser, Region &region, unsigned numRegionArgs,
273  TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
274  RegionBuilderFn regionBuilder) {
275  if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
276  return parser.emitError(
277  parser.getCurrentLocation(),
278  llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
279  "region expects {0} args, got {1}",
280  numRegionArgs, inputTypes.size() + outputTypes.size()));
281  }
282 
283  OpBuilder opBuilder(parser.getContext());
284  fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs,
285  regionBuilder);
286  return success();
287 }
288 
289 static ParseResult
291  SmallVectorImpl<Type> &resultTypes) {
292  if (parser.parseOptionalArrowTypeList(resultTypes))
293  return failure();
294  return success();
295 }
296 
298  OperationState &result,
299  unsigned numRegionArgs,
300  RegionBuilderFn regionBuilder) {
301  // TODO: Enable when ods-gen supports captures.
302  SmallVector<Type, 1> inputTypes, outputTypes;
303  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
304  return failure();
305 
306  // TODO: consider merging results parsing into region parsing.
307  // Need to wait for declarative assembly resolution to decide.
308  SmallVector<Type, 1> outputTensorsTypes;
309  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
310  return failure();
311  result.addTypes(outputTensorsTypes);
312 
313  std::unique_ptr<Region> region = std::make_unique<Region>();
314  if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
315  outputTypes, result.attributes.getAttrs(),
316  regionBuilder))
317  return failure();
318  result.addRegion(std::move(region));
319 
320  return success();
321 }
322 
324  TypeRange resultTypes) {
325  if (resultTypes.empty())
326  return;
327  p.printOptionalArrowTypeList(resultTypes);
328 }
329 
331  ValueRange inputs, ValueRange outputs) {
333  op->getAttrs(),
334  /*elidedAttrs=*/{"operandSegmentSizes",
335  // See generated code in
336  // LinalgNamedStructuredOps.yamlgen.cpp.inc
337  "linalg.memoized_indexing_maps"});
338 
339  // Printing is shared with generic ops, except for the region and
340  // attributes.
341  printCommonStructuredOpParts(p, inputs, outputs);
342 
343  // Results printing.
345 
346  // Region is elided.
347 }
348 
349 //===----------------------------------------------------------------------===//
350 // Region builder helper.
351 // TODO: Move this to a utility library.
352 // The public methods on this class are referenced directly from generated code.
353 // Helper build the unary, binary, and type conversion functions defined by the
354 // DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the code that uses this
355 // class.
356 //
357 // Implementations of the math functions must be polymorphic over numeric types,
358 // internally performing necessary casts. If the function application makes no
359 // sense, then the only recourse is to assert and return nullptr. This can be
360 // extended later if it becomes possible to fail construction of the region. The
361 // invariant should be enforced at a higher level.
362 //
363 // TODO: These helpers are currently type polymorphic over the class of integer
364 // and floating point types, but they will not internally cast within bit
365 // widths of a class (mixed precision such as i8->i32) or across classes
366 // (i.e. mixed float and integer). Many such combinations are ambiguous or need
367 // to be handled with care and work is being considered to extend the op
368 // language to make such cases explicit. In the mean-time, violating this will
369 // fail verification, which is deemed acceptable.
370 //===----------------------------------------------------------------------===//
371 
namespace {

/// Helper used by ods-generated region builders (see
/// LinalgNamedStructuredOps.yamlgen.cpp.inc) to emit the scalar body of a
/// structured op. All ops are inserted at the end of `block` and, except for
/// value-derived locations, use the builder's unknown location.
class RegionBuilderHelper {
public:
  RegionBuilderHelper(OpBuilder &builder, Block &block)
      : builder(builder), block(block) {}

  // Build the unary functions defined by OpDSL.
  Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
    // Unary functions are only supported on floating-point operands.
    if (!isFloatingPoint(arg))
      llvm_unreachable("unsupported non numeric type");
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (unaryFn) {
    case UnaryFn::exp:
      return builder.create<math::ExpOp>(arg.getLoc(), arg);
    case UnaryFn::log:
      return builder.create<math::LogOp>(arg.getLoc(), arg);
    case UnaryFn::abs:
      return builder.create<math::AbsFOp>(arg.getLoc(), arg);
    case UnaryFn::ceil:
      return builder.create<math::CeilOp>(arg.getLoc(), arg);
    case UnaryFn::floor:
      return builder.create<math::FloorOp>(arg.getLoc(), arg);
    case UnaryFn::negf:
      return builder.create<arith::NegFOp>(arg.getLoc(), arg);
    }
    llvm_unreachable("unsupported unary function");
  }

  // Build the binary functions defined by OpDSL. The concrete op emitted
  // depends on the operand class: complex, float, integer, or i1 (bool).
  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
    bool allComplex = isComplex(arg0) && isComplex(arg1);
    bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
    bool allInteger = isInteger(arg0) && isInteger(arg1);
    // i1 operands get boolean lowerings (e.g. add -> or, mul -> and).
    bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
                   arg1.getType().getIntOrFloatBitWidth() == 1;
    if (!allComplex && !allFloatingPoint && !allInteger)
      llvm_unreachable("unsupported non numeric type");
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (binaryFn) {
    case BinaryFn::add:
      if (allComplex)
        return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::sub:
      if (allComplex)
        return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        llvm_unreachable("unsupported operation: sub with bools");
      return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::mul:
      if (allComplex)
        return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::div:
      if (allComplex)
        return builder.create<complex::DivOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::DivFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        llvm_unreachable("unsupported operation: div with bools");
      return builder.create<arith::DivSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::div_unsigned:
      if (!allInteger || allBool)
        llvm_unreachable("unsupported operation: unsigned div not on uint");
      return builder.create<arith::DivUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
    }
    llvm_unreachable("unsupported binary function");
  }

  // Build the type functions defined by OpDSL.
  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
    switch (typeFn) {
    case TypeFn::cast_signed:
      return cast(toType, operand, false);
    case TypeFn::cast_unsigned:
      return cast(toType, operand, true);
    }
    llvm_unreachable("unsupported type conversion function");
  }

  /// Terminate the block with a linalg.yield of `values`.
  void yieldOutputs(ValueRange values) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    builder.create<YieldOp>(loc, values);
  }

  /// Materialize an arith.constant from the textual attribute `value`
  /// (e.g. "3.0 : f32"); the string must parse to a TypedAttr.
  Value constant(const std::string &value) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    Attribute valueAttr = parseAttribute(value, builder.getContext());
    return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
  }

  /// Materialize a linalg.index op for iteration dimension `dim`.
  Value index(int64_t dim) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
  }

  Type getIntegerType(unsigned width) {
    return IntegerType::get(builder.getContext(), width);
  }

  Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
  Type getFloat64Type() { return Float64Type::get(builder.getContext()); }

private:
  // Generates operations to cast the given operand to a specified type.
  // If the cast cannot be performed, a warning will be issued and the
  // operand returned as-is (which will presumably yield a verification
  // issue downstream).
  Value cast(Type toType, Value operand, bool isUnsignedCast) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    auto loc = operand.getLoc();
    return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
  }

  bool isComplex(Value value) {
    return llvm::isa<ComplexType>(value.getType());
  }
  bool isFloatingPoint(Value value) {
    return llvm::isa<FloatType>(value.getType());
  }
  bool isInteger(Value value) {
    return llvm::isa<IntegerType>(value.getType());
  }

  OpBuilder &builder;
  Block &block;
};

} // namespace
540 
541 //===----------------------------------------------------------------------===//
542 // CopyOp
543 //===----------------------------------------------------------------------===//
544 
545 namespace {
546 
547 struct EraseSelfCopy : OpRewritePattern<CopyOp> {
549  LogicalResult matchAndRewrite(CopyOp copyOp,
550  PatternRewriter &rewriter) const override {
551  if (copyOp.getInputs() != copyOp.getOutputs())
552  return rewriter.notifyMatchFailure(copyOp, "not a self copy");
553  if (copyOp.hasPureBufferSemantics())
554  rewriter.eraseOp(copyOp);
555  else
556  rewriter.replaceOp(copyOp, copyOp.getInputs());
557 
558  return success();
559  }
560 };
561 
562 } // namespace
563 
/// Register CopyOp canonicalization patterns (self-copy elimination).
void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseSelfCopy>(context);
}
568 
569 //===----------------------------------------------------------------------===//
570 // FillOp
571 //===----------------------------------------------------------------------===//
572 
573 namespace {
574 
575 /// Fold linalg.fill -> tensor.expand/collapse_shape chain.
576 ///
577 /// For such op chains, we can create new linalg.fill ops with the result
578 /// type of the tensor.expand/collapse_shape op.
579 template <typename TensorReshapeOp>
580 struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
582  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
583  PatternRewriter &rewriter) const override {
584  auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
585  if (!oldFill)
586  return failure();
587 
588  Location loc = oldFill.getLoc();
589  auto newInit = rewriter.create<TensorReshapeOp>(
590  loc, reshapeOp.getResultType(), oldFill.output(),
591  reshapeOp.getReassociation());
592  rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
593  ValueRange{newInit});
594 
595  return success();
596  }
597 };
598 
599 /// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
600 /// filling value are the same.
601 struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
603 
604  LogicalResult matchAndRewrite(tensor::PadOp padOp,
605  PatternRewriter &rewriter) const override {
606  auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
607  if (!fillOp)
608  return failure();
609 
610  // We can only fold if the padding value is the same as the original
611  // filling value.
612  Value padValue = padOp.getConstantPaddingValue();
613  if (!padValue || fillOp.value() != padValue)
614  return failure();
615 
616  ReifiedRankedShapedTypeDims reifiedShape;
617  if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
618  return rewriter.notifyMatchFailure(
619  padOp, "failed to reify tensor.pad op result shape");
620 
621  auto emptyTensor = rewriter.create<tensor::EmptyOp>(
622  padOp.getLoc(), reifiedShape.front(),
623  padOp.getResultType().getElementType());
624  Value replacement =
625  rewriter
626  .create<FillOp>(fillOp.getLoc(), ValueRange{padValue},
627  ValueRange{emptyTensor})
628  .getResult(0);
629  if (replacement.getType() != padOp.getResultType()) {
630  replacement = rewriter.create<tensor::CastOp>(
631  fillOp.getLoc(), padOp.getResultType(), replacement);
632  }
633  rewriter.replaceOp(padOp, replacement);
634  return success();
635  }
636 };
637 
638 /// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
639 /// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
640 /// filling value are the same.
641 struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
643 
644  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
645  PatternRewriter &rewriter) const override {
646  auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
647  if (!srcPadOp)
648  return failure();
649 
650  if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
651  return failure();
652 
653  // Walk back the tensor.insert_slice chain and find the first destination
654  // value at the start of the chain.
655  Value firstDest = insertOp.getDest();
656  while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
657  if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
658  return failure();
659 
660  // Make sure the range of values accessed are disjoint. Without this, we
661  // cannot fold tensor.pad away.
662  bool disjoint = false;
663  for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
664  // If the dimension has dynamic offset/size, we cannot guarantee
665  // disjoint. So just skip it.
666  if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
667  insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
668  prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
669  continue;
670 
671  // Get the range start and end, inclusively for both.
672  int64_t prevStart = prevOp.getStaticOffset(i);
673  int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
674  prevOp.getStaticStride(i);
675  int64_t nextStart = insertOp.getStaticOffset(i);
676  int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
677  insertOp.getStaticStride(i);
678  if (prevEnd < nextStart || nextEnd < prevStart) {
679  disjoint = true;
680  break;
681  }
682  }
683 
684  if (!disjoint)
685  break;
686  firstDest = prevOp.getDest();
687  }
688 
689  // Check whether the first destination is a fill op. For overlapped cases,
690  // this also cannot be true.
691  auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
692  if (!dstFillOp)
693  return failure();
694 
695  // We can only fold if the padding value is the same as the original
696  // filling value.
697  Value padValue = srcPadOp.getConstantPaddingValue();
698  if (!padValue || dstFillOp.value() != padValue)
699  return failure();
700 
701  SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
702  SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();
703 
704  Location loc = insertOp.getLoc();
705  MLIRContext *context = getContext();
706 
707  AffineExpr sym0, sym1;
708  bindSymbols(context, sym0, sym1);
709  auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);
710 
711  // Calculate the new offsets for the insert. It should be the old offsets
712  // plus low padding sizes.
713  SmallVector<OpFoldResult, 4> newOffsets;
714  for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
715  newOffsets.push_back(affine::makeComposedFoldedAffineApply(
716  rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
717  }
718 
719  RankedTensorType srcPadType = srcPadOp.getSourceType();
721  for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
722  if (srcPadType.isDynamicDim(i)) {
723  newSizes.push_back(
724  rewriter.create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
725  .getResult());
726  } else {
727  newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
728  }
729  }
730 
731  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
732  insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
733  newSizes, insertOp.getMixedStrides());
734  return success();
735  }
736 };
737 
738 /// Fold tensor.extract(linalg.fill(<input>)) into <input>
739 struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
740 public:
742 
743  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
744  PatternRewriter &rewriter) const override {
745  // See if tensor input of tensor.extract op is the result of a linalg.fill
746  // op.
747  auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
748  if (!fillOp)
749  return failure();
750 
751  // Get scalar input operand of linalg.fill op.
752  Value extractedScalar = fillOp.getInputs()[0];
753 
754  // Replace tensor.extract op with scalar value used to fill the tensor.
755  rewriter.replaceOp(extractOp, extractedScalar);
756  return success();
757  }
758 };
759 
760 /// Folds pack(fill) into a single fill op if
761 /// 1. The pack op does not have padding value, or
762 /// 2. The filled value and padding value are the same.
763 static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
764  tensor::PackOp packOp) {
765  auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
766  if (!fillOp)
767  return failure();
768 
769  if (auto paddingValue = packOp.getPaddingValue())
770  if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
771  return failure();
772 
773  Value packOpDest = packOp.getDest();
774  if (!packOpDest.hasOneUse())
775  return failure();
776 
777  return rewriter.create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(),
778  packOp.getDest());
779 }
780 
781 /// Wrapper pattern that applies foldFillPackIntoFillOp method.
782 struct FoldFillWithPack : public OpRewritePattern<tensor::PackOp> {
783 public:
784  FoldFillWithPack(MLIRContext *context)
785  : OpRewritePattern<tensor::PackOp>(context) {}
786 
787  LogicalResult matchAndRewrite(tensor::PackOp packOp,
788  PatternRewriter &rewriter) const override {
789  auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
790  if (failed(fillOp))
791  return failure();
792  rewriter.replaceOp(packOp, fillOp.value().result());
793  return success();
794  }
795 };
796 
797 /// Fold fill with copy.
798 struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
800 
801  LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
802  PatternRewriter &rewriter) const override {
803  if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
804  rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
805  fillOp.getInputs(),
806  copyOp.getOutputs());
807  return success();
808  }
809  if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
810  rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
811  fillOp.getOutputs());
812  return success();
813  }
814  return failure();
815  }
816 };
817 
818 /// Fold fill with transpose.
819 struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
821 
822  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
823  PatternRewriter &rewriter) const override {
824  if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
825  rewriter.replaceOpWithNewOp<FillOp>(
826  transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
827  transposeOp.getDpsInitOperand(0)->get());
828  return success();
829  }
830  return failure();
831  }
832 };
833 
834 } // namespace
835 
/// Register FillOp canonicalization patterns: folding fill through copy,
/// extract, pack, pad, reshape, insert_slice-of-pad, and transpose.
void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results
      .add<FoldFillWithCopy, FoldFillWithTensorExtract, FoldFillWithPack,
           FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
           FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
           FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
}
844 
845 //===----------------------------------------------------------------------===//
846 // GenericOp
847 //===----------------------------------------------------------------------===//
848 
849 static void buildGenericRegion(
850  OpBuilder &builder, Location loc, Region &region, ValueRange inputs,
851  ValueRange outputs,
852  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
853  SmallVector<Type, 4> blockArgTypes;
854  SmallVector<Location, 4> blockArgLocs;
855  for (ValueRange container : {inputs, outputs}) {
856  for (Value v : container) {
857  Type t = v.getType();
858  blockArgTypes.push_back(
859  isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
860  blockArgLocs.push_back(v.getLoc());
861  }
862  }
863 
864  OpBuilder::InsertionGuard guard(builder);
865  Block *bodyBlock =
866  builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
867  bodyBuild(builder, loc, bodyBlock->getArguments());
868 }
869 
870 void GenericOp::getAsmBlockArgumentNames(Region &region,
871  OpAsmSetValueNameFn setNameFn) {
872  for (Value v : getRegionInputArgs())
873  setNameFn(v, "in");
874  for (Value v : getRegionOutputArgs())
875  setNameFn(v, "out");
876 }
877 
/// Build a generic op from already-materialized attributes and optionally
/// create its region via `bodyBuild`. Extra `attributes` are attached to the
/// op verbatim.
void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
    ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall);
  result.addAttributes(attributes);
  // The delegated build created an empty region; fill it in if requested.
  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, outputs, bodyBuild);
}
891 
892 void GenericOp::build(
893  OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
894  ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
895  ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
896  StringRef libraryCall,
897  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
898  ArrayRef<NamedAttribute> attributes) {
899  build(builder, result, resultTensorTypes, inputs, outputs,
900  builder.getAffineMapArrayAttr(indexingMaps),
901  builder.getArrayAttr(llvm::to_vector(llvm::map_range(
902  iteratorTypes,
903  [&](utils::IteratorType iter) -> mlir::Attribute {
904  return IteratorTypeAttr::get(builder.getContext(), iter);
905  }))),
906  doc.empty() ? StringAttr() : builder.getStringAttr(doc),
907  libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
908  bodyBuild, attributes);
909 }
910 
/// Convenience builder without result tensor types (memref / in-place form).
void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
}
921 
/// Convenience builder without result types, doc, or library call.
void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}
932 
/// Convenience builder with result tensor types but no doc or library call.
void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}
944 
void GenericOp::print(OpAsmPrinter &p) {
  p << " ";

  // Print extra attributes.
  // Collect the core linalg trait attributes into a leading dictionary so
  // they print before the ins/outs operand lists.
  auto genericAttrNames = linalgTraitAttrNames();

  llvm::StringSet<> genericAttrNamesSet;
  genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
  SmallVector<NamedAttribute, 8> genericAttrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
      // Convert IteratorType enums into the string representation. This is
      // needed, because tests still use the old format when 'iterator_types'
      // attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
      SmallVector<Attribute> iteratorTypeNames =
          llvm::to_vector(llvm::map_range(
              iteratorTypes, [&](utils::IteratorType t) -> Attribute {
                return StringAttr::get(getContext(), stringifyIteratorType(t));
              }));

      genericAttrs.emplace_back(
          getIteratorTypesAttrName(),
          ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
      genericAttrs.push_back(attr);
    }
  }
  if (!genericAttrs.empty()) {
    auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
    p << genericDictAttr;
  }

  // Printing is shared with named ops, except for the region and attributes
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());

  // operandSegmentSizes is implied by the printed ins/outs lists; elide it
  // from the trailing `attrs =` dictionary as well.
  genericAttrNames.push_back("operandSegmentSizes");
  genericAttrNamesSet.insert(genericAttrNames.back());

  bool hasExtraAttrs = false;
  for (NamedAttribute n : (*this)->getAttrs()) {
    if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
      break;
  }
  if (hasExtraAttrs) {
    p << " attrs = ";
    p.printOptionalAttrDict((*this)->getAttrs(),
                            /*elidedAttrs=*/genericAttrNames);
  }

  // Print region.
  if (!getRegion().empty()) {
    p << ' ';
    p.printRegion(getRegion());
  }

  // Print results.
  printNamedStructuredOpResults(p, getResultTensors().getTypes());
}
1007 
ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
  DictionaryAttr dictAttr;
  // Parse the core linalg traits that must check into a dictAttr.
  // The name is unimportant as we will overwrite result.attributes.
  // The core linalg traits must contain the information necessary to pass the
  // verifier.
  llvm::SMLoc attributeLocation = parser.getCurrentLocation();
  if (parser.parseAttribute(dictAttr, "_", result.attributes))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert array of string into an array of IteratorType enums. This is
  // needed, because tests still use the old format when 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(attributeLocation)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = utils::symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(parser.getCurrentLocation())
             << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

  // Parsing is shared with named ops, except for the region.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Optional attributes may be added.
  if (succeeded(parser.parseOptionalKeyword("attrs")))
    if (failed(parser.parseEqual()) ||
        failed(parser.parseOptionalAttrDict(result.attributes)))
      return failure();

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parser.parseRegion(*region, {}))
    return failure();
  result.addRegion(std::move(region));

  // Generic ops may specify that a subset of its outputs are tensors. Such
  // outputs are specified in the result type.
  // TODO: may need to move output parsing before region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  return success();
}
1073 
/// Shared memory-effect computation for structured ops: memref inputs are
/// read; memref outputs are both read and written. Tensor operands carry no
/// memory effects.
static void getGenericEffectsImpl(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects,
    ValueRange results, const ValueRange inputOperands,
    ValueRange outputOperands) {
  for (auto operand : inputOperands) {
    if (!llvm::isa<MemRefType>(operand.getType()))
      continue;
    effects.emplace_back(MemoryEffects::Read::get(), operand,
                         SideEffects::DefaultResource::get());
  }
  for (auto operand : outputOperands) {
    if (!llvm::isa<MemRefType>(operand.getType()))
      continue;
    // Outputs are conservatively marked as read as well as written, since
    // the payload may use the existing output value.
    effects.emplace_back(MemoryEffects::Read::get(), operand,
                         SideEffects::DefaultResource::get());
    effects.emplace_back(MemoryEffects::Write::get(), operand,
                         SideEffects::DefaultResource::get());
  }
}
1094 
void GenericOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  // Delegate to the shared helper over DPS inputs/inits.
  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
                        getDpsInits());
}
1101 
// No op-specific verification here; structural invariants are presumably
// enforced by trait/interface verifiers registered elsewhere.
LogicalResult GenericOp::verify() { return success(); }
1103 
1104 namespace {
1105 
/// Remove any linalg operation (on tensors) that are just copying
/// the values from inputs to the results. Requirements are
/// 1) All iterator types are parallel
/// 2) The body contains just a yield operation with the yielded values being
/// the arguments corresponding to the operands.
template <typename OpTy>
struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy linalgOp,
                                PatternRewriter &rewriter) const override {
    // Check all indexing maps are identity.
    if (llvm::any_of(linalgOp.getIndexingMapsArray(),
                     [](AffineMap map) { return !map.isIdentity(); }))
      return failure();

    // Check that the body of the linalg operation is just a linalg.yield
    // operation.
    Block &body = linalgOp->getRegion(0).front();
    if (!llvm::hasSingleElement(body))
      return failure();
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return failure();

    // In the buffer case, we need to check exact buffer equality.
    if (linalgOp.hasPureBufferSemantics()) {
      // Only the trivial in == out single-operand copy can be erased here.
      if (linalgOp.getNumDpsInputs() == 1 && linalgOp.getNumDpsInits() == 1 &&
          linalgOp.getDpsInputOperand(0)->get() ==
              linalgOp.getDpsInitOperand(0)->get()) {
        rewriter.eraseOp(linalgOp);
        return success();
      }
      return failure();
    }

    // Mixed semantics is not supported yet.
    if (!linalgOp.hasPureTensorSemantics())
      return failure();

    // Get the argument number of the returned values. That is the operand
    // number to use for replacing uses of this operation.
    SmallVector<Value> returnedArgs;
    for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
      // Each yielded value must be a block argument of this op's own body.
      auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
      if (!yieldArg || yieldArg.getOwner() != &body)
        return failure();
      unsigned argumentNumber = yieldArg.getArgNumber();
      Value returnedArg = linalgOp->getOperand(argumentNumber);
      Type resultType = linalgOp->getResult(yieldVal.index()).getType();
      // The input can have a different type than the result, e.g. a dynamic
      // input dimension can be turned into a static output dimension.
      Type returnType = returnedArg.getType();
      if (returnType != resultType) {
        // Distinguish between sparse conversion or dense tensor casting.
        // TODO: unify the two ops?
        if (sparse_tensor::getSparseTensorEncoding(returnType) ||
            sparse_tensor::getSparseTensorEncoding(resultType))
          returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        else {
          if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
                                                 resultType))
            return failure();
          returnedArg = rewriter.create<tensor::CastOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        }
      }
      returnedArgs.push_back(returnedArg);
    }

    if (returnedArgs.size() != linalgOp->getNumResults())
      return failure();
    rewriter.replaceOp(linalgOp, returnedArgs);
    return success();
  }
};
1183 
1184 } // namespace
1185 
void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Identity generics (pure copies) fold away entirely.
  results.add<EraseIdentityLinalgOp<GenericOp>>(context);
}
1190 
LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  // Fold memref.cast producers into this op's operands where possible.
  return memref::foldMemRefCast(*this);
}
1194 
1195 //===----------------------------------------------------------------------===//
1196 // MapOp
1197 //===----------------------------------------------------------------------===//
1198 
/// Common parser for destination-style ops: parses `ins`/`outs`, adds tensor
/// result types, then parses required (via `parseAttrsFn`) and optional
/// attributes.
static ParseResult parseDstStyleOp(
    OpAsmParser &parser, OperationState &result,
    function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
        nullptr) {
  // Parse `ins` and `outs`.
  SmallVector<Type, 4> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
                                   /*addOperandSegmentSizes=*/false))
    return failure();

  // Add result types.
  // Only tensor outputs produce results; memref outputs are in-place.
  for (Type outputType : outputTypes) {
    if (llvm::isa<RankedTensorType>(outputType))
      result.addTypes(outputType);
  }

  // Parse required attributes.
  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}
1224 
1225 void MapOp::getAsmBlockArgumentNames(Region &region,
1226  OpAsmSetValueNameFn setNameFn) {
1227  for (Value v : getRegionInputArgs())
1228  setNameFn(v, "in");
1229 }
1230 
1231 void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1232  if (!getResults().empty())
1233  setNameFn(getResults().front(), "mapped");
1234 }
1235 
/// Build a map op over `inputs` with destination `init`; a tensor init adds a
/// corresponding result. The region is created via `bodyBuild` if provided.
void MapOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, init);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  // Note: only the inputs become block arguments of the mapper region.
  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, /*outputs=*/{}, bodyBuild);
}
1252 
1253 static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
1254  const OperationName &payloadOpName,
1255  const NamedAttrList &payloadOpAttrs,
1256  ArrayRef<Value> operands,
1257  bool initFirst = false) {
1258  OpBuilder b(parser.getContext());
1259  Region *body = result.addRegion();
1260  Block &block = body->emplaceBlock();
1261  b.setInsertionPointToStart(&block);
1262  SmallVector<Value> bbArgs;
1263  for (auto &operand : operands) {
1264  block.addArgument(
1265  llvm::cast<ShapedType>(operand.getType()).getElementType(),
1266  b.getUnknownLoc());
1267  }
1268  SmallVector<Value> payloadOpOperands;
1269  // If initFirst flag is enabled, we consider init as the first position of
1270  // payload operands.
1271  if (initFirst) {
1272  payloadOpOperands.push_back(block.getArguments().back());
1273  for (const auto &arg : block.getArguments().drop_back())
1274  payloadOpOperands.push_back(arg);
1275  } else {
1276  payloadOpOperands = {block.getArguments().begin(),
1277  block.getArguments().end()};
1278  }
1279 
1280  Operation *payloadOp = b.create(
1281  result.location, b.getStringAttr(payloadOpName.getStringRef()),
1282  payloadOpOperands,
1283  TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
1284  .getElementType()},
1285  payloadOpAttrs);
1286  b.create<YieldOp>(result.location, payloadOp->getResults());
1287 }
1288 
ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
  // Optional short form: `{ payload.op <attrs> }` before the ins/outs lists.
  std::optional<OperationName> payloadOpName;
  NamedAttrList payloadOpAttrs;
  if (succeeded(parser.parseOptionalLBrace())) {
    FailureOr<OperationName> operationName = parser.parseCustomOperationName();
    if (failed(operationName))
      return failure();
    if (parser.parseOptionalAttrDict(payloadOpAttrs))
      return failure();
    payloadOpName = operationName.value();
    if (parser.parseRBrace())
      return failure();
  }

  if (parseDstStyleOp(parser, result))
    return failure();

  if (payloadOpName.has_value()) {
    // Short form: synthesize the region from the payload op. The init
    // operand is dropped; only inputs become block arguments.
    addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
                         ArrayRef(result.operands).drop_back());
  } else {
    // Long form: explicit region with its argument list.
    SmallVector<OpAsmParser::Argument> regionArgs;
    if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
                                 /*allowType=*/true, /*allowAttrs=*/true)) {
      return failure();
    }
    Region *body = result.addRegion();
    if (parser.parseRegion(*body, regionArgs))
      return failure();
  }
  return success();
}
1321 
1322 // Retrieve the operation from the body, if it is the only one (except
1323 // yield) and if it gets the same amount of arguments as the body does.
1324 // If initFirst flag is enabled, we check that init takes the first position in
1325 // operands of payload.
1326 static Operation *findPayloadOp(Block *body, bool initFirst = false) {
1327  if (body->getOperations().size() != 2)
1328  return nullptr;
1329  Operation &payload = body->getOperations().front();
1330  assert(isa<YieldOp>(body->getOperations().back()));
1331 
1332  if (payload.getNumOperands() == 0 ||
1333  payload.getNumOperands() != body->getNumArguments())
1334  return nullptr;
1335  if (initFirst) {
1336  // check init
1337  if (payload.getOperands().back() != body->getArgument(0))
1338  return nullptr;
1339  // check rest
1340  for (const auto &[operand, bbArg] :
1341  llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
1342  if (bbArg != operand)
1343  return nullptr;
1344  }
1345  } else {
1346  for (const auto &[operand, bbArg] :
1347  llvm::zip(payload.getOperands(), body->getArguments())) {
1348  if (bbArg != operand)
1349  return nullptr;
1350  }
1351  }
1352  return &payload;
1353 }
1354 
1355 void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
1356  SmallVector<StringRef> elidedAttrs;
1357  std::string attrToElide;
1358  p << " { " << payloadOp->getName().getStringRef();
1359  for (const auto &attr : payloadOp->getAttrs()) {
1360  auto fastAttr =
1361  llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1362  if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1363  attrToElide = attr.getName().str();
1364  elidedAttrs.push_back(attrToElide);
1365  break;
1366  }
1367  }
1368  p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
1369  p << " }";
1370 }
1371 
1372 void MapOp::print(OpAsmPrinter &p) {
1373  Block *mapper = getBody();
1374  Operation *payloadOp = findPayloadOp(mapper);
1375  if (payloadOp) {
1376  printShortForm(p, payloadOp);
1377  }
1378 
1379  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1380  p.printOptionalAttrDict((*this)->getAttrs());
1381 
1382  if (!payloadOp) {
1383  // Print region if the payload op was not detected.
1384  p.increaseIndent();
1385  p.printNewline();
1386  p << "(";
1387  llvm::interleaveComma(mapper->getArguments(), p,
1388  [&](auto arg) { p.printRegionArgument(arg); });
1389  p << ") ";
1390 
1391  p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
1392  p.decreaseIndent();
1393  }
1394 }
1395 
LogicalResult MapOp::verify() {
  auto *bodyBlock = getBody();
  auto blockArgs = bodyBlock->getArguments();

  // Checks if the number of `inputs` match the arity of the `mapper` region.
  if (getInputs().size() != blockArgs.size())
    return emitOpError() << "expects number of operands to match the arity of "
                            "mapper, but got: "
                         << getInputs().size() << " and " << blockArgs.size();

  // The parameters of mapper should all match the element type of inputs.
  for (const auto &[bbArgType, inputArg] :
       llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
    auto inputElemType =
        llvm::cast<ShapedType>(inputArg.getType()).getElementType();
    if (bbArgType != inputElemType) {
      return emitOpError() << "expected element type of input " << inputElemType
                           << " to match bbArg type " << bbArgType;
    }
  }

  // The shape of each input must match the shape of the output.
  auto outputShape = getInit().getType().getShape();
  for (Type inputArgType : TypeRange{getInputs()}) {
    auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
    if (inputElemShape != outputShape) {
      return emitOpError() << "expected shape of input (" << inputElemShape
                           << ") to match shape of output (" << outputShape
                           << ")";
    }
  }

  return success();
}
1430 
1431 SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
1432  int64_t rank = getInit().getType().getRank();
1433  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1434 }
1435 
ArrayAttr MapOp::getIndexingMaps() {
  Builder builder(getContext());
  int64_t rank = getInit().getType().getRank();
  int64_t numIndexingMaps = getOperands().size();
  // Every operand of an elementwise map uses the identity indexing map.
  return builder.getAffineMapArrayAttr(SmallVector<AffineMap>(
      numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
}
1443 
void MapOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  // Delegate to the shared helper over DPS inputs/inits.
  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
                        getDpsInits());
}
1450 
1451 //===----------------------------------------------------------------------===//
1452 // ReduceOp
1453 //===----------------------------------------------------------------------===//
1454 
1455 void ReduceOp::getAsmBlockArgumentNames(Region &region,
1456  OpAsmSetValueNameFn setNameFn) {
1457  for (Value v : getRegionInputArgs())
1458  setNameFn(v, "in");
1459  for (Value v : getRegionOutputArgs())
1460  setNameFn(v, "init");
1461 }
1462 
1463 void ReduceOp::getAsmResultNames(
1464  function_ref<void(Value, StringRef)> setNameFn) {
1465  if (!getResults().empty())
1466  setNameFn(getResults().front(), "reduced");
1467 }
1468 
1469 void ReduceOp::build(
1470  OpBuilder &builder, OperationState &result, ValueRange inputs,
1471  ValueRange inits, ArrayRef<int64_t> dimensions,
1472  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1473  ArrayRef<NamedAttribute> attributes) {
1474  build(builder, result, TypeRange{}, inputs, inits, dimensions);
1475  result.addAttributes(attributes);
1476 
1477  // Add output types for `RankedTensorType` output arguments.
1478  for (Value init : inits) {
1479  Type initType = init.getType();
1480  if (llvm::isa<RankedTensorType>(initType))
1481  result.addTypes(initType);
1482  }
1483 
1484  if (bodyBuild)
1485  buildGenericRegion(builder, result.location, *result.regions.front(),
1486  inputs, inits, bodyBuild);
1487 }
1488 
1489 SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
1490  int64_t inputRank =
1491  llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1492  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
1493  utils::IteratorType::parallel);
1494  for (int64_t reductionDim : getDimensions())
1495  iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1496  return iteratorTypes;
1497 }
1498 
ArrayAttr ReduceOp::getIndexingMaps() {
  int64_t inputRank =
      llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
  // Inputs use the full identity map over the input rank.
  SmallVector<AffineMap> affineMaps(
      getNumDpsInputs(),
      AffineMap::getMultiDimIdentityMap(inputRank, getContext()));
  // Inits drop the reduced dimensions from the identity map.
  AffineMap resultMap =
      AffineMap::getMultiDimIdentityMap(inputRank, getContext())
          .dropResults(getDimensions());
  for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
    affineMaps.push_back(resultMap);
  return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
}
1512 
void ReduceOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  // Delegate to the shared helper over DPS inputs/inits.
  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
                        getDpsInits());
}
1519 
/// Parse `<attributeName> = [i64, ...]` and store it into `attributes`.
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
                                          NamedAttrList &attributes,
                                          StringRef attributeName) {
  if (parser.parseKeyword(attributeName) || parser.parseEqual())
    return failure();

  attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{}));
  return success();
}
1529 
ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
  // Optional short form: `{ payload.op <attrs> }` before the ins/outs lists.
  std::optional<OperationName> payloadOpName;
  NamedAttrList payloadOpAttrs;
  if (succeeded(parser.parseOptionalLBrace())) {
    FailureOr<OperationName> operationName = parser.parseCustomOperationName();
    if (failed(operationName))
      return failure();
    if (parser.parseOptionalAttrDict(payloadOpAttrs))
      return failure();
    payloadOpName = operationName.value();
    if (parser.parseRBrace())
      return failure();
  }

  // Parse the common parts plus the mandatory `dimensions` attribute.
  if (parseDstStyleOp(
          parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
            return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
          }))
    return failure();

  if (payloadOpName.has_value()) {
    // Short form: synthesize the combiner from the payload op, with the init
    // as the payload's first operand.
    addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
                         ArrayRef(result.operands), /*initFirst=*/true);
  } else {
    // Long form: explicit combiner region with its argument list.
    SmallVector<OpAsmParser::Argument> regionArgs;
    if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
                                 /*allowType=*/true, /*allowAttrs=*/true)) {
      return failure();
    }

    Region *body = result.addRegion();
    if (parser.parseRegion(*body, regionArgs))
      return failure();
  }

  return success();
}
1567 
/// Print ` <attributeName> = [i64, ...] ` in the short keyword form.
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
                                   ArrayRef<int64_t> attributeValue) {
  p << ' ' << attributeName << " = [" << attributeValue << "] ";
}
1572 
1573 void ReduceOp::print(OpAsmPrinter &p) {
1574  Block *mapper = getBody();
1575  Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
1576  if (payloadOp) {
1577  printShortForm(p, payloadOp);
1578  }
1579 
1580  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1581  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
1582  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
1583  if (!payloadOp) {
1584  // Print region if the payload op was not detected.
1585  p.increaseIndent();
1586  p.printNewline();
1587  p << "(";
1588  llvm::interleaveComma(mapper->getArguments(), p,
1589  [&](auto arg) { p.printRegionArgument(arg); });
1590  p << ") ";
1591 
1592  p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
1593  p.decreaseIndent();
1594  }
1595 }
1596 
LogicalResult ReduceOp::verify() {
  ArrayRef<int64_t> dimensionsRef = getDimensions();

  // All inputs must share one shape, and all inits must share one shape.
  for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
    if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
        llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
      return emitOpError() << "expects all inputs to have the same shapes. "
                              "Shape at input-index "
                           << i
                           << " is not equal to the shape at input-index 0.";
    }
  }
  for (int64_t i = 1; i < getNumDpsInits(); ++i) {
    if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
        llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
      return emitOpError() << "expects all outputs to have the same shapes. "
                              "Shape at output-index "
                           << i
                           << " is not equal to the shape at output-index 0.";
    }
  }
  auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
  auto initType = llvm::cast<ShapedType>(getInits()[0].getType());

  // Reduction dimensions must be valid indices into the input rank.
  DenseSet<int64_t> dimensionsToReduce;
  for (int64_t dimension : dimensionsRef) {
    if (dimension < 0 || dimension >= inputType.getRank()) {
      return emitOpError()
             << "dimensions for reduction should be in the range [0, "
             << inputType.getRank() - 1 << "].";
    }
    dimensionsToReduce.insert(dimension);
  }

  auto inputDims = inputType.getShape();
  auto initDims = initType.getShape();

  // Input dimensions that will be left after the reduction.
  SmallVector<int64_t> reducedInputDims;
  for (const auto &en : llvm::enumerate(inputDims)) {
    if (!dimensionsToReduce.count(en.index()))
      reducedInputDims.push_back(en.value());
  }

  if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
    return emitOpError() << "number of dimensions after reduction "
                         << reducedInputDims.size()
                         << " doesn't match the init rank "
                         << initType.getRank();
  }

  if (reducedInputDims != initDims)
    return emitOpError() << "init dimensions [" << initDims
                         << "] doesn't match input dimensions after reduction ["
                         << reducedInputDims << "]";

  Block *block = getBody();
  if (block->getNumArguments() != this->getNumOperands())
    return emitOpError()
           << "mismatching number of operands and block arguments";

  // Check that the first block arguments match the element type of the inputs.
  for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
    Type inputElementType =
        llvm::cast<ShapedType>(input.getType()).getElementType();
    if (inputElementType != bbArg.getType())
      return emitOpError()
             << "input element type " << inputElementType
             << " does not match corresponding block argument type "
             << bbArg.getType();
  }

  // Check that the last block arguments match the element type of the outputs.
  for (auto [output, bbArg] : llvm::zip(
           getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
    auto outputElementType =
        llvm::cast<ShapedType>(output.getType()).getElementType();
    if (outputElementType != bbArg.getType())
      return emitOpError()
             << "output element type " << outputElementType
             << " does not match corresponding block argument type "
             << bbArg.getType();
  }
  return success();
}
1682 
1683 //===----------------------------------------------------------------------===//
1684 // TransposeOp
1685 //===----------------------------------------------------------------------===//
1686 
1687 static void buildIdentityRegion(OpBuilder &builder, Location loc,
1688  Region &region, ValueRange inputs,
1689  ValueRange outputs) {
1690  buildGenericRegion(builder, loc, region, inputs, outputs,
1691  [](OpBuilder &b, Location loc, ValueRange args) {
1692  b.create<linalg::YieldOp>(loc, args[0]);
1693  });
1694 }
1695 
/// Build a transpose of `input` into `init` according to `permutation`; a
/// tensor init adds a corresponding result. The region is the canonical
/// identity (yield) body.
void TransposeOp::build(::mlir::OpBuilder &builder,
                        ::mlir::OperationState &result, Value input, Value init,
                        DenseI64ArrayAttr permutation,
                        ArrayRef<NamedAttribute> attributes) {
  result.addOperands(input);
  result.addOperands(init);
  result.addAttribute(getPermutationAttrName(result.name), permutation);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
                      init);
}
1713 
/// Convenience builder taking the permutation as a plain integer array.
void TransposeOp::build(::mlir::OpBuilder &builder,
                        ::mlir::OperationState &result, Value input, Value init,
                        ArrayRef<int64_t> permutation,
                        ArrayRef<NamedAttribute> attributes) {
  build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
        attributes);
}
1721 
ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
  // Parse ins/outs plus the mandatory `permutation` attribute.
  if (failed(parseDstStyleOp(
          parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
            return parseDenseI64ArrayAttr(parser, attributes, "permutation");
          })))
    return failure();

  // The region is never spelled out in the assembly; always rebuild the
  // canonical identity (yield) body.
  OpBuilder builder(parser.getContext());
  buildIdentityRegion(builder, result.location, *result.addRegion(),
                      /*inputs=*/result.operands,
                      /*outputs=*/{});
  return success();
}
1735 
/// Suggest "transposed" as the printed SSA name for the first result
/// (no-op for the result-less memref form).
1736 void TransposeOp::getAsmResultNames(
1737  function_ref<void(Value, StringRef)> setNameFn) {
1738  if (!getResults().empty())
1739  setNameFn(getResults().front(), "transposed");
1740 }
1741 
// Print ins/outs, then the permutation array, then the remaining
// attributes with `permutation` elided (it was already printed inline).
1743  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1744  printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
1745  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
1746 }
1747 
// Verifier: `permutation` must be a valid permutation, input/init ranks
// must match the permutation size, and each init dim i must equal the
// input dim it is sourced from (permutation[i]). Dims are compared as
// plain integers, so a dynamic dim only matches another dynamic dim.
1749  ArrayRef<int64_t> permutationRef = getPermutation();
1750 
1751  if (!isPermutationVector(permutationRef))
1752  return emitOpError("permutation is not valid");
1753 
1754  auto inputType = getInput().getType();
1755  auto initType = getInit().getType();
1756 
1757  int64_t rank = inputType.getRank();
1758 
1759  if (rank != initType.getRank())
1760  return emitOpError() << "input rank " << rank
1761  << " does not match init rank " << initType.getRank();
1762 
1763  if (rank != static_cast<int64_t>(permutationRef.size()))
1764  return emitOpError() << "size of permutation " << permutationRef.size()
1765  << " does not match the argument rank " << rank;
1766 
1767  auto inputDims = inputType.getShape();
1768  auto initDims = initType.getShape();
1769 
1770  for (int64_t i = 0; i < rank; ++i) {
1771  int64_t inputDim = inputDims[permutationRef[i]];
1772  int64_t initDim = initDims[i];
1773 
1774  if (inputDim != initDim) {
1775  return emitOpError() << "dim(result, " << i << ") = " << initDim
1776  << " doesn't match dim(input, permutation[" << i
1777  << "]) = " << inputDim;
1778  }
1779  }
1780 
1781  return success();
1782 }
1783 
/// Every loop of a transpose is parallel; the loop count is the init rank.
1784 SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
1785  int64_t rank = getInit().getType().getRank();
1786  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1787 }
1788 
/// Indexing maps: the input map is built from `permutation` (as unsigned
/// positions) and the output map is the rank-`rank` identity.
1789 ArrayAttr TransposeOp::getIndexingMaps() {
1790  Builder builder(getContext());
1791  int64_t rank = getInit().getType().getRank();
1792  return builder.getAffineMapArrayAttr(
1794  llvm::to_vector_of<unsigned>(getPermutation()), getContext())),
1795  builder.getMultiDimIdentityMap(rank)});
1796 }
1797 
/// Delegates memory-effect computation to the shared DPS helper, passing
/// results, inputs, and inits.
1798 void TransposeOp::getEffects(
1800  &effects) {
1801  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
1802  getDpsInits());
1803 }
1804 
/// Fold transposes that cannot reorder data: an empty permutation (rank-0)
/// or an identity permutation folds to the input operand.
/// NOTE(review): no guard on the input/result type match (memref semantics
/// or dynamic-vs-static tensor dims) is visible in this body — confirm the
/// folder cannot fire where the folded value's type differs from the
/// result type.
1805 LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
1807  // Single dimension transpose.
1808  if (getPermutation().size() == 0) {
1809  result.push_back(getInput());
1810  return success();
1811  }
1812  // Identity permutation.
1813  if (isIdentityPermutation(getPermutation())) {
1814  result.push_back(getInput());
1815  return success();
1816  }
1817 
1818  return failure();
1819 }
1820 
1821 //===----------------------------------------------------------------------===//
1822 // BroadcastOp
1823 //===----------------------------------------------------------------------===//
1824 
/// Builder taking the broadcast dimensions as a DenseI64ArrayAttr:
/// registers operands/attributes, adds a result type only for a ranked
/// tensor init, and creates the implicit identity payload region.
1825 void BroadcastOp::build(::mlir::OpBuilder &builder,
1826  ::mlir::OperationState &result, Value input, Value init,
1827  DenseI64ArrayAttr dimensions,
1828  ArrayRef<NamedAttribute> attributes) {
1829  result.addOperands(input);
1830  result.addOperands(init);
1831  result.addAttribute(getDimensionsAttrName(result.name), dimensions);
1832  result.addAttributes(attributes);
1833 
1834  // Add output types for `RankedTensorType` output arguments.
1835  Type initType = init.getType();
1836  if (llvm::isa<RankedTensorType>(initType))
1837  result.addTypes(initType);
1838 
1839  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
1840  init);
1841 }
1842 
/// Convenience builder taking the dimensions as a plain integer list;
/// wraps them in a DenseI64ArrayAttr and forwards to the attribute builder.
1843 void BroadcastOp::build(::mlir::OpBuilder &builder,
1844  ::mlir::OperationState &result, Value input, Value init,
1845  ArrayRef<int64_t> dimensions,
1846  ArrayRef<NamedAttribute> attributes) {
1847  build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
1848  attributes);
1849 }
1850 
// Parse the destination-style operand list together with the mandatory
// `dimensions` dense-i64-array attribute, then materialize the implicit
// identity payload region (not spelled in the assembly format).
1852  if (failed(parseDstStyleOp(
1853  parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1854  return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
1855  })))
1856  return failure();
1857 
1858  OpBuilder builder(parser.getContext());
1859  buildIdentityRegion(builder, result.location, *result.addRegion(),
1860  /*inputs=*/result.operands,
1861  /*outputs=*/{});
1862  return success();
1863 }
1864 
/// Suggest "broadcasted" as the printed SSA name for the first result
/// (no-op for the result-less memref form).
1865 void BroadcastOp::getAsmResultNames(
1866  function_ref<void(Value, StringRef)> setNameFn) {
1867  if (!getResults().empty())
1868  setNameFn(getResults().front(), "broadcasted");
1869 }
1870 
// Print ins/outs, then the dimensions array, then the remaining
// attributes with `dimensions` elided (already printed inline).
1872  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1873  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
1874  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
1875 }
1876 
// Verifier: init rank must equal input rank plus the number of added
// `dimensions`; each added dimension index must lie in [0, initRank);
// and every non-added init dim must match the corresponding input dim.
1878  ArrayRef<int64_t> dimensionsRef = getDimensions();
1879 
1880  auto inputType = getInput().getType();
1881  auto initType = getInit().getType();
1882 
1883  int64_t inputRank = inputType.getRank();
1884  int64_t initRank = initType.getRank();
1885 
1886  auto inputShape = inputType.getShape();
1887  auto initShape = initType.getShape();
1888 
1889  if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
1890  return emitOpError() << "input rank plus added dimensions does not "
1891  "match init rank. input rank: "
1892  << inputRank
1893  << ", dimensions size: " << dimensionsRef.size()
1894  << ", init rank: " << initRank;
1895 
1896  for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
1897  if (dim < 0 || dim >= initRank)
1898  return emitOpError() << "dimension " << idx
1899  << " is out of range. expected range: [0, "
1900  << initRank - 1 << "], got: " << dim;
1901  }
1902 
1903  // Mapping from input dims to init dims.
1904  SmallVector<int64_t> dimMap;
1905  for (auto dim : llvm::seq<int64_t>(0, initRank)) {
1906  if (!llvm::is_contained(dimensionsRef, dim))
1907  dimMap.push_back(dim);
1908  }
1909 
1910  for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
1911  // This dimension is mapped from the input. Init and input dims should
1912  // match.
1913  if (inputShape[inputDimIdx] != initShape[initDimIdx])
1914  return emitOpError() << "input dim " << inputDimIdx
1915  << " should match init dim " << initDimIdx
1916  << ". input: " << inputShape[inputDimIdx]
1917  << ", init: " << initShape[initDimIdx];
1918  }
1919 
1920  return success();
1921 }
1922 
/// Every loop of a broadcast is parallel; the loop count is the init rank.
1923 SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
1924  int64_t rank = getInit().getType().getRank();
1925  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1926 }
1927 
/// Indexing maps: the input map is the identity with the broadcast
/// `dimensions` dropped; the output map is the full identity.
1928 ArrayAttr BroadcastOp::getIndexingMaps() {
1929  Builder builder(getContext());
1930  int64_t rank = getInit().getType().getRank();
1931  return builder.getAffineMapArrayAttr(
1932  {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
1933  builder.getMultiDimIdentityMap(rank)});
1934 }
1935 
/// Delegates memory-effect computation to the shared DPS helper, passing
/// results, inputs, and inits.
1936 void BroadcastOp::getEffects(
1938  &effects) {
1939  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
1940  getDpsInits());
1941 }
1942 
/// Register the pattern that erases broadcasts acting as identities.
1943 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
1944  MLIRContext *context) {
1945  results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
1946 }
1947 
1948 //===----------------------------------------------------------------------===//
1949 // YieldOp
1950 //===----------------------------------------------------------------------===//
1951 
// Print the yielded values and their types only when there are operands;
// the attribute dictionary is printed in between either way.
1953  if (getNumOperands() > 0)
1954  p << ' ' << getOperands();
1955  p.printOptionalAttrDict((*this)->getAttrs());
1956  if (getNumOperands() > 0)
1957  p << " : " << getOperandTypes();
1958 }
1959 
// Parse an optional operand list, optional attributes, and — only when
// operands were given — a colon-separated type list; then resolve the
// operands against those types.
1962  SmallVector<Type, 2> types;
1963  SMLoc loc = parser.getCurrentLocation();
1964  return failure(parser.parseOperandList(opInfo) ||
1965  parser.parseOptionalAttrDict(result.attributes) ||
1966  (!opInfo.empty() && parser.parseColonTypeList(types)) ||
1967  parser.resolveOperands(opInfo, types, loc, result.operands));
1968 }
1969 
1970 // Check the operand number and types must match the element types of the
1971 // LinalgOp interface's shaped operands.
/// Verify that `op` yields exactly one value per DPS init of the
/// enclosing `linalgOp`, and that each yielded type equals the element
/// type (for shaped inits) or the scalar type of the matching init.
1972 static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
1973  if (op.getNumOperands() != linalgOp.getNumDpsInits())
1974  return op.emitOpError("expected number of yield values (")
1975  << op.getNumOperands()
1976  << ") to match the number of inits / outs operands of the enclosing "
1977  << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";
1978 
1979  for (OpOperand &opOperand : op->getOpOperands()) {
1980  OpOperand *outputOperand =
1981  linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
1982  Type elementType = outputOperand->get().getType();
1983  if (isa<MemRefType, RankedTensorType>(elementType))
1984  elementType = getElementTypeOrSelf(outputOperand->get().getType());
1985  if (opOperand.get().getType() != elementType)
1986  return op.emitOpError("type of yield operand ")
1987  << (opOperand.getOperandNumber() + 1) << " ("
1988  << opOperand.get().getType() << ") doesn't match "
1989  << "the element type of the enclosing linalg.generic op ("
1990  << elementType << ")";
1991  }
1992  return success();
1993 }
1994 
// Verifier: the yield must sit in the single non-empty region of a
// parent implementing the LinalgOp interface; operand checking is
// delegated to verifyYield above.
1996  auto *parentOp = (*this)->getParentOp();
1997  if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
1998  return emitOpError("expected single non-empty parent region");
1999 
2000  if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2001  return verifyYield(*this, linalgOp);
2002 
2003  return emitOpError("expected parent op with LinalgOp interface");
2004 }
2005 
2006 //===----------------------------------------------------------------------===//
2007 // IndexOp
2008 //===----------------------------------------------------------------------===//
2009 
// Verifier: linalg.index must be nested in a LinalgOp and its `dim`
// must be strictly below the enclosing op's loop count.
2011  auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2012  if (!linalgOp)
2013  return emitOpError("expected parent op with LinalgOp interface");
2014  if (linalgOp.getNumLoops() <= getDim())
2015  return emitOpError("expected dim (")
2016  << getDim() << ") to be lower than the number of loops ("
2017  << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
2018  return success();
2019 }
2020 
2021 /////// Operations corresponding to library calls defined with Tablegen ////////
2022 
2023 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2024 
2025 #define GET_OP_CLASSES
2026 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2027 
2028 #define GET_OP_CLASSES
2029 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2030 
/// Return `*maybeMap` when present; otherwise the empty map for rank 0
/// or the rank-`rank` multi-dim identity map.
2031 AffineMap mlir::linalg::extractOrIdentityMap(std::optional<AffineMap> maybeMap,
2032  unsigned rank,
2033  MLIRContext *context) {
2034  if (maybeMap)
2035  return *maybeMap;
2036  if (rank == 0)
2037  return AffineMap::get(context);
2038  return AffineMap::getMultiDimIdentityMap(rank, context);
2039 }
2040 
/// Produce `num` consecutive affine dim expressions starting at
/// `startIdx`, advancing `startIdx` past them for the caller.
2042 mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
2043  MLIRContext *context) {
2045  res.reserve(num);
2046  for (unsigned i = 0; i < num; ++i)
2047  res.push_back(getAffineDimExpr(startIdx++, context));
2048  return res;
2049 }
2050 
// Concatenate the two expression ranges `a` and `b` into one vector.
2053  auto rangeA = llvm::make_range(a.begin(), a.end());
2054  auto rangeB = llvm::make_range(b.begin(), b.end());
2055  auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2056  return llvm::to_vector<4>(concatRanges);
2057 }
2058 
/// Append a mangled spelling of type `t` to `ss` for library-call names:
/// memrefs become "view" + ("sx" for dynamic | "<size>x") per dim +
/// element type + optional "as<int>" memory space; vectors become
/// "vector" + x-separated shape + element type; signless int/index/float
/// types print as-is. Any other type (or non-integer memory space) fails.
2059 static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
2060  if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
2061  ss << "view";
2062  for (auto size : memref.getShape())
2063  if (size < 0)
2064  ss << "sx";
2065  else
2066  ss << size << "x";
2067  if (failed(appendMangledType(ss, memref.getElementType())))
2068  return failure();
2069  if (auto as = memref.getMemorySpace()) {
2070  if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
2071  ss << "as" << attr.getInt();
2072  else
2073  return failure();
2074  }
2075  return success();
2076  }
2077  if (auto vec = llvm::dyn_cast<VectorType>(t)) {
2078  ss << "vector";
2079  llvm::interleave(
2080  vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
2081  if (failed(appendMangledType(ss, vec.getElementType())))
2082  return failure();
2083  return success();
2084  }
2085  if (t.isSignlessIntOrIndexOrFloat()) {
2086  ss << t;
2087  return success();
2088  }
2089  return failure();
2090 }
2091 
// Build a library-call name from the op name (dots replaced by
// underscores), an optional unary/binary fn attribute suffix, and the
// mangled operand types separated by underscores; returns the empty
// string when any operand type cannot be mangled. The trailing
// pop_back removes the final separator underscore.
2093  assert(isa<LinalgOp>(op));
2094  std::string name(op->getName().getStringRef().str());
2095  std::string fun = "";
2096  for (NamedAttribute kv : op->getAttrs()) {
2097  if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2098  fun = stringifyEnum(ufa.getValue()).str() + "_";
2099  } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2100  fun = stringifyEnum(bfa.getValue()).str() + "_";
2101  }
2102  }
2103  name.reserve(128);
2104  std::replace(name.begin(), name.end(), '.', '_');
2105  llvm::raw_string_ostream ss(name);
2106  ss << "_" << fun;
2107  for (Type t : op->getOperandTypes()) {
2108  if (failed(appendMangledType(ss, t)))
2109  return std::string();
2110  ss << "_";
2111  }
2112  std::string res = ss.str();
2113  res.pop_back();
2114  return res;
2115 }
2116 
2117 //===----------------------------------------------------------------------===//
2118 // Canonicalizers and Folders.
2119 //===----------------------------------------------------------------------===//
2120 
2121 namespace {
/// Erase linalg ops that provably perform zero iterations because one of
/// their memref operands has a 0-sized dimension.
2122 struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
2124 
2125  LogicalResult matchAndRewrite(LinalgOp op,
2126  PatternRewriter &rewriter) const override {
2127  for (OpOperand &opOperand : op->getOpOperands()) {
2128  // Linalg "inputs" may be either tensor or memref type.
2129  // tensor<0xelt_type> is a convention that may not always mean
2130  // "0 iterations". Only erase in cases we see memref<...x0x...>.
2131  auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2132  if (!mt)
2133  continue;
2134  if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2135  rewriter.eraseOp(op);
2136  return success();
2137  }
2138  }
2139  return failure();
2140  }
2141 };
2142 
2143 /// Fold LinalgOps with `tensor.cast` consumer if the `tensor.cast` has
2144 /// result that is more static than the linalg op.
2145 struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
2147 
2148  LogicalResult matchAndRewrite(tensor::CastOp castOp,
2149  PatternRewriter &rewriter) const override {
2150  if (!tensor::canFoldIntoProducerOp(castOp))
2151  return failure();
2152 
2153  auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2154  if (!linalgOp)
2155  return failure();
2156 
2157  // Cast can be in conditionally reachable region, in which case folding
2158  // will generate invalid code. Only conservatively fold ops in same block
2159  // for now.
2160  if (castOp->getBlock() != linalgOp->getBlock())
2161  return failure();
2162 
2163  OpBuilder::InsertionGuard guard(rewriter);
2164  rewriter.setInsertionPoint(linalgOp);
2165 
2166  Location loc = linalgOp.getLoc();
2167  OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2168  unsigned resultNumber = resultValue.getResultNumber();
2169  auto resultType =
2170  llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2171  // Replace the `outs` for the result with a `tensor.cast`. This cast is now
2172  // going from a more dynamic shape to a less dynamic shape. If the producer
2173  // for this cast, i.e. producer of the out operand, is also an operation
2174  // that folds with tensor.cast consumer (like this pattern), the cast will
2175  // continue to propagate as far up the stack as it can go.
2176  OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2177  Value newOperand =
2178  rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
2179  SmallVector<Value> newOperands = linalgOp.getDpsInputs();
2180  SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
2181  linalgOp.getDpsInits().end());
2182  outputOperands[resultNumber] = newOperand;
2183  newOperands.append(outputOperands.begin(), outputOperands.end());
2184 
2185  // Clone the op with the updated (more static) result type.
2186  SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
2186  linalgOp->result_type_end());
2187  resultTypes[resultNumber] = resultType;
2188  Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2189 
2190  // Create a tensor.cast operation back to the original type.
2191  Value castBack = rewriter.create<tensor::CastOp>(
2192  loc, resultValue.getType(), newOp->getResult(resultNumber));
2193 
2194  SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
2195  results[resultNumber] = castBack;
2196  rewriter.replaceOp(linalgOp, results);
2197  rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
2198  return success();
2199  }
2200 };
2201 
2202 /// For each of the operand in `operands` this function maps the static sizes of
2203 /// dimensions to their affine dim expressions.
2204 static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
2205  llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
2206  for (OpOperand &opOperand : operands) {
2207  // Scalars carry no shape information.
2208  if (linalgOp.isScalar(&opOperand))
2209  continue;
2210  Value src = opOperand.get();
2211  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2212  auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2213 
2214  // Get the `sourceShape` of the `sourceType`. If the operand is a result of
2215  // `tensor.cast` operation and source of the cast operation has a static
2216  // shape, then assign it to the `sourceShape`.
2217  auto *parentOp = src.getDefiningOp();
2218  ArrayRef<int64_t> sourceShape = sourceType.getShape();
2219  if (parentOp) {
2220  if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2221  Value castSource = castOp.getSource();
2222  auto castSourceType =
2223  llvm::dyn_cast<RankedTensorType>(castSource.getType());
2224  if (castSourceType && castSourceType.hasStaticShape())
2225  sourceShape = castSourceType.getShape();
2226  }
2227  }
2228 
2229  // If the source shape's dimension has a static shape, map the affine dim
2230  // expression to the known static size. Only plain dim expressions are
2231  // recorded; compound expressions are skipped.
2232  for (unsigned i = 0; i < sourceShape.size(); i++) {
2233  if (sourceType.isDynamicDim(i))
2234  continue;
2235  if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2236  affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2237  }
2238  }
2239 }
2238 
2239 /// Creates new operand w.r.t 'opOperand' of `linalgOp` with static sizes
2240 /// mapped in `affineExprToSize`. New operands are created in `newOperands` and
2241 /// their result types is stored in `resultTypes`. If `opOperand` requires no
2242 /// change then `changeNeeded` is false and same operand is added in the
2243 /// `newOperands` list.
2244 static void createNewOperandWithStaticSizes(
2245  Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
2246  llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
2247  SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
2248  bool &changeNeeded) {
2249  Value src = opOperand->get();
2250  newOperands.push_back(src);
2251  // Scalars have no shape to refine.
2252  if (linalgOp.isScalar(opOperand))
2253  return;
2254  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2255  Type resultType = sourceType;
2256  // Already fully static inits keep their type and contribute it as a
2257  // result type unchanged.
2258  if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
2259  resultTypes.push_back(resultType);
2260  return;
2261  }
2262  ArrayRef<int64_t> sourceShape = sourceType.getShape();
2263  AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
2264  SmallVector<int64_t> newShape;
2265  // If operand is updated with new shape, `newOperandNeeded` will be
2266  // true.
2267  bool newOperandNeeded = false;
2268  for (unsigned i = 0; i < sourceShape.size(); i++) {
2269  int64_t dimShape = sourceShape[i];
2270  AffineExpr dimExpr = sourceMap.getResult(i);
2271  if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
2272  newShape.push_back(dimShape);
2273  continue;
2274  }
2275  // Dimension has a dynamic shape and corresponding affine dim
2276  // expression is present in the map. So assign the size for the
2277  // given affine dim expression to the dimension.
2278  newShape.push_back(affineExprToSize[dimExpr]);
2279  newOperandNeeded = true;
2280  }
2281  resultType = RankedTensorType::get(newShape, sourceType.getElementType());
2282  if (newOperandNeeded) {
2283  changeNeeded = true;
2284  // Get the new operand value given its size and element type by
2285  // casting it.
2286  Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
2287  unsigned index = opOperand->getOperandNumber();
2288  newOperands[index] = newOperand;
2289  }
2290  if (linalgOp.isDpsInit(opOperand))
2291  resultTypes.push_back(resultType);
2292 }
2290 
2291 /// Static shapes for the operands can be inferred if any one of the operands
2292 /// have a static shape. This can be done by referring to the affine dim
2293 /// expressions for the operand.
2294 struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
2296 
2297  LogicalResult matchAndRewrite(LinalgOp linalgOp,
2298  PatternRewriter &rewriter) const override {
2299  // Only tensor-semantics ops can be rebuilt with refined types.
2300  if (!linalgOp.hasPureTensorSemantics())
2301  return failure();
2302 
2303  // Maps must be projected permutations.
2304  if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
2305  return !map.isProjectedPermutation();
2306  }))
2307  return failure();
2308 
2309  // Maps affine dim expressions to the static size of that dimension.
2310  llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
2311  Location loc = linalgOp.getLoc();
2312 
2313  // For each of the affine dim expression, check if the size is known. If
2314  // known add that in the map.
2315  populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2316 
2317  SmallVector<Value> newOperands;
2318  SmallVector<Type> resultTypes;
2319 
2320  // `changeNeeded` is `false` if the operands of `linalgOp` require no
2321  // change in their types.
2322  bool changeNeeded = false;
2323  newOperands.reserve(linalgOp->getNumOperands());
2324  resultTypes.reserve(linalgOp.getNumDpsInits());
2325 
2326  // Iterate over all the operands and update the static sizes.
2327  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2328  createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2329  affineExprToSize, linalgOp, newOperands,
2330  resultTypes, changeNeeded);
2331  }
2332 
2333  // If the generic op has all the required static information, no
2334  // canonicalization needed.
2335  if (!changeNeeded)
2336  return failure();
2337 
2338  // Clone op.
2339  Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2340  SmallVector<Value> replacements;
2341  replacements.reserve(newOp->getNumResults());
2342  // Cast refined results back to the original (more dynamic) types so
2343  // existing users keep type-correct operands.
2344  for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
2345  Value newResult = std::get<1>(it);
2346  Value oldResult = std::get<0>(it);
2347  Type newType = newResult.getType();
2348  Type oldType = oldResult.getType();
2349  replacements.push_back(
2350  (newType != oldType)
2351  ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
2352  : newResult);
2353  }
2354  rewriter.replaceOp(linalgOp, replacements);
2355  return success();
2356  }
2357 };
2355 
2356 } // namespace
2357 
2358 // All named ops canonicalizers and folders are auto-generated in the
2359 // .cpp.inc.
2360 
2361 //===----------------------------------------------------------------------===//
2362 // SoftmaxOp
2363 //===----------------------------------------------------------------------===//
2364 
// Verifier: input and output shapes must be compatible, and the softmax
// `dimension` must be a valid axis of the input.
2366  ShapedType inputType = getInputOperandType();
2367  ShapedType outputType = getOutputOperandType();
2368 
2369  ArrayRef<int64_t> inputShape = inputType.getShape();
2370  ArrayRef<int64_t> outputShape = outputType.getShape();
2371  if (failed(verifyCompatibleShape(inputShape, outputShape)))
2372  return emitOpError("incompatible output shape");
2373 
2374  int64_t inputRank = getInputOperandRank();
2375  int64_t dimension = getDimension();
2376  if ((dimension < 0) || (dimension >= inputRank))
2377  return emitOpError("incorrect dimension specified");
2378 
2379  return success();
2380 }
2381 
/// Iteration domain: one range per input dimension with offset 0,
/// stride 1, and a (possibly dynamic) size taken from the input.
2382 SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
2383  int64_t operandRank = getInputOperandRank();
2384  SmallVector<Range> loopBounds(operandRank);
2385  Location loc = getLoc();
2386  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
2387  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
2388  Value source = getInput();
2389  for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
2390  loopBounds[dim].offset = zero;
2391  loopBounds[dim].size = getDimValue(builder, loc, source, dim);
2392  loopBounds[dim].stride = one;
2393  }
2394  return loopBounds;
2395 }
2396 
/// All loops are parallel except the softmax dimension, which reduces.
2397 SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
2398  SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
2399  utils::IteratorType::parallel);
2400  iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2401  return iteratorTypes;
2402 }
2403 
// Tile by slicing both input and output with the given offsets/sizes and
// unit strides, then cloning the op onto the slices; a result type is
// only produced for tensor semantics.
2405 SoftmaxOp::getTiledImplementation(OpBuilder &builder,
2406  ArrayRef<OpFoldResult> offsets,
2407  ArrayRef<OpFoldResult> sizes) {
2408  int64_t rank = getInputOperandRank();
2409  auto oneAttr = builder.getI64IntegerAttr(1);
2410  SmallVector<OpFoldResult> strides(rank, oneAttr);
2411  SmallVector<Value> tiledOperands;
2412  tiledOperands.emplace_back(
2413  getSlice(builder, getLoc(), getInput(), offsets, sizes, strides));
2414  tiledOperands.emplace_back(
2415  getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides));
2416 
2417  SmallVector<Type, 4> resultTypes;
2418  if (hasPureTensorSemantics())
2419  resultTypes.push_back(tiledOperands[1].getType());
2420  Operation *tiledOp =
2421  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2422 
2423  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
2424 }
2425 
/// The single result tile coincides with the iteration-space tile, so the
/// offsets/sizes pass through unchanged; any other result index fails.
2426 LogicalResult SoftmaxOp::getResultTilePosition(
2427  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2428  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2429  SmallVector<OpFoldResult> &resultSizes) {
2430  if (resultNumber == 0) {
2431  resultOffsets.assign(offsets.begin(), offsets.end());
2432  resultSizes.assign(sizes.begin(), sizes.end());
2433  return success();
2434  }
2435  return failure();
2436 }
2437 
2438 // cast(dynamic) -> static.
/// Folding is limited to absorbing memref.cast producers (cast(dynamic)
/// -> static), delegated to the shared memref helper.
2439 LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
2440  return memref::foldMemRefCast(*this);
2441 }
2442 
// Reify the result shape dim-by-dim: static output dims become index
// attributes (sized from the input), dynamic dims become `dim` values
// materialized from the input operand.
2445  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2447  Location loc = getOperation()->getLoc();
2448  IRRewriter rewriter(b);
2449  auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2450  auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2451  for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2452  if (!outputShapedType.isDynamicDim(dim)) {
2453  // Static dim: Return IntegerAttr.
2454  shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
2455  } else {
2456  // Dynamic dim: Return Value.
2457  OpFoldResult ofr = createOrFoldDimOp(b, loc, getInput(), dim);
2458  shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
2459  }
2460  }
2461  reifiedReturnShapes.emplace_back(std::move(shapes));
2462  return success();
2463 }
2464 
/// Delegates memory-effect computation to the shared DPS helper, passing
/// results, inputs, and inits.
2465 void SoftmaxOp::getEffects(
2467  &effects) {
2468  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
2469  getDpsInits());
2470 }
2471 
2472 // Helper functions for softmax decomposition.
2473 // @{
2474 
2475 // Helper function to produce the iterator types (reduction or parallel) and
2476 // affine maps for the iterators used in the decomposition of softmax.
2477 // This method creates:
2478 // If allParallel == true:
2479 // - iterator type: {parallel, ..., parallel}
2480 // - affine maps:
2481 // -- identity with inputRank dimensions.
2482 // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2483 // where N == inputRank.
2484 //
2485 // If allParallel == false:
2486 // - iterator type at dim(i) == parallel for i != \p dim and
2487 // dim(dim) == reduction.
2488 // - affine map:
2489 // -- identity with inputRank dimensions.
2490 // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2491 // where N == inputRank.
2492 static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
2494  int64_t dim, bool allParallel = false) {
2495  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
2496  utils::IteratorType::parallel);
2497  // Unless forced all-parallel, the softmax dimension is a reduction.
2498  if (!allParallel)
2499  iteratorTypes[dim] = utils::IteratorType::reduction;
2500  MLIRContext *ctxt = builder.getContext();
2501  auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
2502  SmallVector<AffineExpr, 2> affineExprs;
2503  // Build the reduced map by dropping dimension `dim` from the identity.
2504  for (int i = 0; i < inputRank; i++) {
2505  if (i != dim)
2506  affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
2507  }
2508  auto reductionMap =
2509  AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
2510  SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
2511  return std::make_tuple(iteratorTypes, indexingMaps);
2512 }
2511 
2512 // Helper function to produce a linalg.generic that computes a reduction on
2513 // dimension \p dim with the operation type \p T.
2514 template <typename T>
2515 static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
2516  int64_t dim) {
2517  auto inputType = cast<ShapedType>(input.getType());
2518  ArrayRef<int64_t> inputShape = inputType.getShape();
2519  int64_t inputRank = inputShape.size();
2520  auto [iteratorTypes, indexingMaps] =
2521  computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
2522  assert(indexingMaps.size() == 2 &&
2523  "We should have two maps: 1 for the input, 1 for the output");
2524  assert(indexingMaps[0].isIdentity() && "input map should be identity");
2525 
2526  // Combine each input element with the running value carried in `output`
2527  // using the binary op `T` (e.g. maximum or addition).
2528  auto genericOp = builder.create<linalg::GenericOp>(
2529  loc, output.getType(), input, output, indexingMaps, iteratorTypes,
2530  [&](OpBuilder &b, Location loc, ValueRange args) {
2531  Value result = b.create<T>(loc, args[0], args[1]);
2532  b.create<linalg::YieldOp>(loc, result);
2533  });
2534  return genericOp.getResult(0);
2535 }
2534 
2535 /// Produce a linalg generic that computes the second step of the softmax
2536 /// decomposition: res = exp(input - max), where \p max is the max of \p input
2537 /// on dimension \p dim.
2538 static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
2539  Value max, Value output, int64_t dim) {
2540  auto inputType = cast<ShapedType>(input.getType());
2541  ArrayRef<int64_t> inputShape = inputType.getShape();
2542  int64_t inputRank = inputShape.size();
2543  // All-parallel maps: `max` is broadcast along `dim` via the reduced map.
2544  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
2545  builder, inputRank, dim, /*allParallel=*/true);
2546  assert(indexingMaps.size() == 2 && "We should have one map for each input");
2547  assert(indexingMaps[0].isIdentity() && "input map should be identity");
2548  // Add the affine map for the output argument.
2549  indexingMaps.push_back(indexingMaps[0]);
2550  auto genericOp = builder.create<linalg::GenericOp>(
2551  loc, input.getType(), ValueRange{input, max}, output, indexingMaps,
2552  iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
2553  Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
2554  Value result = b.create<math::ExpOp>(loc, diff);
2555  b.create<linalg::YieldOp>(loc, result);
2556  });
2557  return genericOp.getResult(0);
2558 }
2558 
2559 /// Produce a linalg generic that computes the final step of the softmax
2560 /// decomposition.
2561 /// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
2562 /// yield n / d
2563 /// }
2564 static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
2565  Value denominator, Value output, int64_t dim) {
2566  auto inputType = cast<ShapedType>(numerator.getType());
2567  ArrayRef<int64_t> inputShape = inputType.getShape();
2568  int64_t inputRank = inputShape.size();
2569  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
2570  builder, inputRank, dim, /*allParallel=*/true);
2571  assert(indexingMaps.size() == 2 &&
2572  "We should have one map for each input (2)");
2573  assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
2574  // Add the affine map for the output tensor.
2575  indexingMaps.push_back(indexingMaps[0]);
2576  auto genericOp = builder.create<linalg::GenericOp>(
2577  loc, numerator.getType(), ValueRange{numerator, denominator}, output,
2578  indexingMaps, iteratorTypes,
2579  [&](OpBuilder &b, Location loc, ValueRange args) {
2580  Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
2581  b.create<linalg::YieldOp>(loc, result);
2582  });
2583  return genericOp.getResult(0);
2584 }
2585 // @} End helper functions for softmax decomposition.
2586 
2587 /// Given an N-dimensional tensor x, this method converts
2588 /// softmax(x) to the following sequence of operations:
2589 ///
2590 /// 1. Compute the max of x along dimension d. This results
2591 /// in a N-1 dimensional tensor m.
2592 /// m = max(x, dim = d)
2593 ///
2594 /// 2. Subtract a broadcasted m from x and exponentiate. This results in
2595 /// a N dimensional tensor z.
2596 /// z = exp(x - m)
2597 ///
2598 /// 3. Compute the sum of z along dimension d. This results in
2599 /// a N-1 dimensional tensor l.
2600 /// l = sum(z, dim = d)
2601 ///
2602 /// 4. Divide z and l. This gives the N-dimensional softmax.
2603 /// softmax = z / l
2604 ///
2605 FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
2606  OpBuilder::InsertionGuard guard(b);
2607  b.setInsertionPoint(*this);
2608  Location loc = getLoc();
2609  Value input = getInput();
2610  ShapedType inputType = getInputOperandType();
2611  Type elementType = inputType.getElementType();
2612  int64_t reductionDim = getDimension();
2613  SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
2614  Value output = getOutput();
2615  dims.erase(dims.begin() + reductionDim);
2616  // Step 1: Compute max along dim.
2617  Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
2618  Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maximumf,
2619  elementType, b, loc,
2620  /*useOnlyFiniteValue=*/true);
2621  Value neutralForMaxFInit =
2622  b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
2623  .result();
2624  Value max = reduce<arith::MaximumFOp>(b, loc, input, neutralForMaxFInit,
2625  reductionDim);
2626 
2627  // Step 2: Subtract max from input and exponentiate.
2628  Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);
2629 
2630  // Step 3: Compute sum along dim.
2631  Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
2632  b, loc, /*useOnlyFiniteValue=*/true);
2633  Value zeroInit =
2634  b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result();
2635  Value denominator =
2636  reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
2637 
2638  // Step 4: Compute softmax.
2639  Value result =
2640  buildDivOp(b, loc, numerator, denominator, output, reductionDim);
2641  return SmallVector<Value>{result};
2642 }
2643 
2644 //===----------------------------------------------------------------------===//
2645 // LinalgDialect
2646 //===----------------------------------------------------------------------===//
2647 
2648 void LinalgDialect::getCanonicalizationPatterns(
2649  RewritePatternSet &results) const {
2650  results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
2651  InferStaticShapeOfOperands>(getContext());
2652 }
2653 
2655  Attribute value, Type type,
2656  Location loc) {
2657  return arith::ConstantOp::materialize(builder, value, type, loc);
2658 }
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static MLIRContext * getContext(OpFoldResult val)
static bool areCastCompatible(const DataLayout &layout, Type lhs, Type rhs)
Checks that two types are the same or can be cast into one another.
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs)
Definition: LinalgOps.cpp:1687
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
Definition: LinalgOps.cpp:271
static std::tuple< SmallVector< utils::IteratorType >, SmallVector< AffineMap > > computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel=false)
Definition: LinalgOps.cpp:2493
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim)
Produce a linalg generic that computes the final step of the softmax decomposition.
Definition: LinalgOps.cpp:2564
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
Definition: LinalgOps.cpp:119
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
Definition: LinalgOps.cpp:2059
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect >> &effects, ValueRange results, const ValueRange inputOperands, ValueRange outputOperands)
Definition: LinalgOps.cpp:1074
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
Definition: LinalgOps.cpp:259
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs)
Definition: LinalgOps.cpp:330
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
Definition: LinalgOps.cpp:1520
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
Definition: LinalgOps.cpp:1568
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim)
Produce a linalg generic that computes the second step of the softmax decomposition: res = exp(input ...
Definition: LinalgOps.cpp:2538
static Operation * findPayloadOp(Block *body, bool initFirst=false)
Definition: LinalgOps.cpp:1326
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
Definition: LinalgOps.cpp:155
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
Definition: LinalgOps.cpp:1199
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
Definition: LinalgOps.cpp:2515
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
Definition: LinalgOps.cpp:1972
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
Definition: LinalgOps.cpp:297
static void buildGenericRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
Definition: LinalgOps.cpp:849
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
Definition: LinalgOps.cpp:290
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)
Return a memref.dim or tensor.dim for the shape of v at dim.
Definition: LinalgOps.cpp:51
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false)
Definition: LinalgOps.cpp:1253
void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
Definition: LinalgOps.cpp:1355
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
Definition: LinalgOps.cpp:323
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
Definition: LinalgOps.cpp:185
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
Base type for affine expression.
Definition: AffineExpr.h:69
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:47
AffineMap dropResults(ArrayRef< int64_t > positions) const
Definition: AffineMap.h:292
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:318
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:395
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:248
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgument getArgument(unsigned i)
Definition: Block.h:126
unsigned getNumArguments()
Definition: Block.h:125
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:152
OpListType & getOperations()
Definition: Block.h:134
BlockArgListType getArguments()
Definition: Block.h:84
Operation & front()
Definition: Block.h:150
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:179
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:183
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:394
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:128
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:269
MLIRContext * getContext() const
Definition: Builders.h:55
Location getUnknownLoc()
Definition: Builders.cpp:27
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:273
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:325
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:766
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void assign(const_iterator inStart, const_iterator inEnd)
Replaces the attributes with new list of attributes.
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:202
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void decreaseIndent()=0
Decrease indentation.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:437
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:522
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:263
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:453
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:465
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
std::optional< RegisteredOperationName > getRegisteredInfo() const
If this operation is registered, returns the registered information, std::nullopt otherwise.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
result_iterator result_begin()
Definition: Operation.h:408
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
unsigned getNumOperands()
Definition: Operation.h:341
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_iterator result_end()
Definition: Operation.h:409
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator end()
Definition: Region.h:56
Block & emplaceBlock()
Definition: Region.h:46
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:125
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
Definition: Types.cpp:107
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:211
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static Attribute parse(AsmParser &parser, Type type)
Parse the short form [42, 100, -1] without any type prefix.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1188
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
Definition: ArithOps.cpp:2530
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
Definition: LinalgOps.cpp:2051
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:98
std::string generateLibraryCallName(Operation *op)
Returns the name mangled library call name to disambiguate between different overloads at the C level...
Definition: LinalgOps.cpp:2092
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
Definition: LinalgOps.cpp:2031
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx ...
Definition: LinalgOps.cpp:2042
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:89
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:44
Fraction abs(const Fraction &f)
Definition: Fraction.h:104
MPInt ceil(const Fraction &f)
Definition: Fraction.h:76
MPInt floor(const Fraction &f)
Definition: Fraction.h:74
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
Definition: TensorOps.cpp:345
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:61
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition: Utils.cpp:169
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:753
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExprs at positions [0 .. sizeof...(exprs)).
Definition: AffineExpr.h:363
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:41
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:599
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:373
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
Region * addRegion()
Create a region that should be attached to the operation.
Container for result values of tiling.