MLIR  20.0.0git
LinalgOps.cpp
Go to the documentation of this file.
1 //===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the Linalg operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
29 #include "mlir/IR/AffineMap.h"
30 #include "mlir/IR/Attributes.h"
33 #include "mlir/IR/Matchers.h"
36 #include "mlir/IR/PatternMatch.h"
39 
40 #include "llvm/ADT/DenseMap.h"
41 #include "llvm/ADT/STLExtras.h"
42 #include "llvm/ADT/SetOperations.h"
43 #include "llvm/ADT/SmallSet.h"
44 #include "llvm/ADT/SmallVector.h"
45 #include "llvm/ADT/StringSet.h"
46 #include "llvm/ADT/TypeSwitch.h"
47 #include "llvm/Support/FormatVariadic.h"
48 #include "llvm/Support/LogicalResult.h"
49 #include "llvm/Support/MathExtras.h"
50 #include "llvm/Support/raw_ostream.h"
51 #include <cassert>
52 #include <optional>
53 
54 using namespace mlir;
55 using namespace mlir::linalg;
56 
57 /// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
59  int64_t dim) {
60  auto type = cast<ShapedType>(v.getType());
61  if (!type.isDynamicDim(dim))
62  return builder.getIndexAttr(type.getDimSize(dim));
63 
64  return getAsOpFoldResult(
66  .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
67  return builder.create<tensor::DimOp>(loc, v, dim);
68  })
69  .Case<MemRefType>([&](MemRefType t) -> Value {
70  return builder.create<memref::DimOp>(loc, v, dim);
71  }));
72 }
73 
74 /// Returns a memref.subview or a tensor.extract_slice based on the type of the
75 /// `source`.
76 static Operation *getSlice(OpBuilder &b, Location loc, Value source,
77  ArrayRef<OpFoldResult> offsets,
79  ArrayRef<OpFoldResult> strides) {
80  return TypeSwitch<Type, Operation *>(source.getType())
81  .Case<RankedTensorType>([&](RankedTensorType t) -> Operation * {
82  return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
83  strides);
84  })
85  .Case<MemRefType>([&](MemRefType type) -> Operation * {
86  return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
87  strides);
88  })
89  .Default([&](Type t) -> Operation * { return nullptr; });
90 }
91 
92 //===----------------------------------------------------------------------===//
93 // Helper functions
94 //===----------------------------------------------------------------------===//
95 
97  int64_t dim) {
98  if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
99  return b.createOrFold<memref::DimOp>(loc, source, dim);
100  if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
101  return b.createOrFold<tensor::DimOp>(loc, source, dim);
102  llvm_unreachable("Expected MemRefType or TensorType");
103 }
104 
106  int64_t dim) {
107  auto shapedType = llvm::cast<ShapedType>(source.getType());
108  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
109  return createOrFoldDimOp(b, loc, source, dim);
110  return b.getIndexAttr(shapedType.getDimSize(dim));
111 }
112 
113 //===----------------------------------------------------------------------===//
114 // Support for named Linalg ops defined in ods-gen.
115 //===----------------------------------------------------------------------===//
116 
119 
120 /// Fills the region of a structured operation using the provided
121 /// `regionBuilder`. The method is used by both named structured ops created by
122 /// ods-gen and by manually defined C++ ops. It is called by both builders and
123 /// parsers and creates a block with arguments corresponding to the elemental
124 /// types of `inputTypes` and `outputTypes`. All output types are asserted to be
125 /// ShapedType.
126 static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
127  TypeRange inputTypes, TypeRange outputTypes,
129  RegionBuilderFn regionBuilder) {
130  assert(llvm::all_of(outputTypes, llvm::IsaPred<ShapedType>));
131 
132  SmallVector<Type, 8> argTypes;
133  SmallVector<Location, 8> argLocs;
134  for (auto containers : {inputTypes, outputTypes}) {
135  for (auto t : containers) {
136  argTypes.push_back(
137  isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
138 
139  // TODO: Pass in a proper location here.
140  argLocs.push_back(opBuilder.getUnknownLoc());
141  }
142  }
143 
144  // RAII.
145  OpBuilder::InsertionGuard guard(opBuilder);
146  Block *body =
147  opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);
148 
149  opBuilder.setInsertionPointToStart(body);
150  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
151  regionBuilder(b, *body, attrs);
152 
153  // indexing_maps is an auto-generated method.
154 
155  // iterator_types is an auto-generated method.
156 }
157 
158 /// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
159 /// The result types are derived automatically if `resultTensorTypes` is none.
160 /// The body of the operation is filled using `regionBuilder`. All ods-gen
161 /// created structured operations use the method to implement their builders.
163  std::optional<TypeRange> resultTensorTypes,
164  ValueRange inputs, ValueRange outputs,
165  ArrayRef<NamedAttribute> attributes,
166  RegionBuilderFn regionBuilder) {
167  // Derive the result types if needed.
168  SmallVector<Type> derivedResultTypes =
169  resultTensorTypes.value_or(TypeRange());
170  if (!resultTensorTypes)
171  copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
172  llvm::IsaPred<RankedTensorType>);
173 
174  state.addOperands(inputs);
175  state.addOperands(outputs);
176  state.addTypes(derivedResultTypes);
177 
178  state.addAttributes(attributes);
179  state.addAttribute(
180  "operandSegmentSizes",
181  b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
182  static_cast<int32_t>(outputs.size())}));
183 
184  // Create and fill the region of the structured operation.
185  Region &region = *state.addRegion();
186  fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
187  state.attributes.getAttrs(), regionBuilder);
188 }
189 
190 static void buildMatmulOp(OpBuilder &b, OperationState &state,
191  std::optional<TypeRange> resultTensorTypes,
192  ValueRange inputs, ValueRange outputs,
193  ArrayRef<NamedAttribute> attributes,
194  RegionBuilderFn regionBuilder,
195  ArrayRef<AffineMap> indexingMaps) {
196  // Initialize indexingMaps attribute, for MatmulOp.
197  SmallVector<Attribute, 3> indexingMapsAttrVal;
198  indexingMapsAttrVal = llvm::map_to_vector(
199  MatmulOp::getDefaultIndexingMaps(b.getContext()),
200  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
201  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
202  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
203  attributes, regionBuilder);
204 }
205 
206 /// Common parsing used for both named structured ops created by ods-gen and by
207 /// manually defined C++ ops. Does not handle regions.
208 static ParseResult
210  SmallVectorImpl<Type> &inputTypes,
211  SmallVectorImpl<Type> &outputTypes,
212  bool addOperandSegmentSizes = true) {
213  SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
215  outputsOperands;
216 
217  if (succeeded(parser.parseOptionalLess())) {
218  if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
219  return failure();
220  }
221  attrsLoc = parser.getCurrentLocation();
222  if (parser.parseOptionalAttrDict(result.attributes))
223  return failure();
224 
225  if (succeeded(parser.parseOptionalKeyword("ins"))) {
226  if (parser.parseLParen())
227  return failure();
228 
229  inputsOperandsLoc = parser.getCurrentLocation();
230  if (parser.parseOperandList(inputsOperands) ||
231  parser.parseColonTypeList(inputTypes) || parser.parseRParen())
232  return failure();
233  }
234 
235  if (succeeded(parser.parseOptionalKeyword("outs"))) {
236  outputsOperandsLoc = parser.getCurrentLocation();
237  if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
238  parser.parseColonTypeList(outputTypes) || parser.parseRParen())
239  return failure();
240  }
241 
242  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
243  result.operands) ||
244  parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
245  result.operands))
246  return failure();
247 
248  if (addOperandSegmentSizes) {
249  // This is a bit complex because we're trying to be backward compatible with
250  // operation syntax that mix the inherent attributes and the discardable
251  // ones in the same dictionary. If the properties are used, we append the
252  // operandSegmentSizes there directly. Otherwise we append it to the
253  // discardable attributes dictionary where it is handled by the generic
254  // Operation::create(...) method.
255  if (result.propertiesAttr) {
256  NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
257  attrs.append("operandSegmentSizes",
259  {static_cast<int32_t>(inputsOperands.size()),
260  static_cast<int32_t>(outputsOperands.size())}));
261  result.propertiesAttr = attrs.getDictionary(parser.getContext());
262  } else {
263  result.addAttribute("operandSegmentSizes",
265  {static_cast<int32_t>(inputsOperands.size()),
266  static_cast<int32_t>(outputsOperands.size())}));
267  }
268  }
269  if (!result.propertiesAttr) {
270  std::optional<RegisteredOperationName> info =
271  result.name.getRegisteredInfo();
272  if (info) {
273  if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
274  return parser.emitError(attrsLoc)
275  << "'" << result.name.getStringRef() << "' op ";
276  })))
277  return failure();
278  }
279  }
280  return success();
281 }
282 
284  ValueRange outputs) {
285  if (!inputs.empty())
286  p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
287  if (!outputs.empty())
288  p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
289 }
290 
291 //===----------------------------------------------------------------------===//
292 // Specific parsing and printing for named structured ops created by ods-gen.
293 //===----------------------------------------------------------------------===//
294 
295 static ParseResult parseNamedStructuredOpRegion(
296  OpAsmParser &parser, Region &region, unsigned numRegionArgs,
297  TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
298  RegionBuilderFn regionBuilder) {
299  if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
300  return parser.emitError(
301  parser.getCurrentLocation(),
302  llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
303  "region expects {0} args, got {1}",
304  numRegionArgs, inputTypes.size() + outputTypes.size()));
305  }
306 
307  OpBuilder opBuilder(parser.getContext());
308  fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs,
309  regionBuilder);
310  return success();
311 }
312 
313 static ParseResult
315  SmallVectorImpl<Type> &resultTypes) {
316  if (parser.parseOptionalArrowTypeList(resultTypes))
317  return failure();
318  return success();
319 }
320 
321 static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
322  OperationState &result,
323  unsigned numRegionArgs,
324  RegionBuilderFn regionBuilder) {
325  // TODO: Enable when ods-gen supports captures.
326  SmallVector<Type, 1> inputTypes, outputTypes;
327  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
328  return failure();
329 
330  // Parse optional attributes.
331  if (parser.parseOptionalAttrDict(result.attributes))
332  return failure();
333 
334  // TODO: consider merging results parsing into region parsing.
335  // Need to wait for declarative assembly resolution to decide.
336  SmallVector<Type, 1> outputTensorsTypes;
337  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
338  return failure();
339  result.addTypes(outputTensorsTypes);
340 
341  std::unique_ptr<Region> region = std::make_unique<Region>();
342  if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
343  outputTypes, result.attributes.getAttrs(),
344  regionBuilder))
345  return failure();
346  result.addRegion(std::move(region));
347 
348  return success();
349 }
350 
352  TypeRange resultTypes) {
353  if (resultTypes.empty())
354  return;
355  p.printOptionalArrowTypeList(resultTypes);
356 }
357 
359  ValueRange inputs, ValueRange outputs,
360  ArrayRef<StringRef> elidedAttrs = {}) {
361  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
362 
363  // Printing is shared with generic ops, except for the region and
364  // attributes.
365  printCommonStructuredOpParts(p, inputs, outputs);
366 
367  // Results printing.
369 
370  // Region is elided.
371 }
372 
373 //===----------------------------------------------------------------------===//
374 // Region builder helper.
375 // TODO: Move this to a utility library.
376 // The public methods on this class are referenced directly from generated code.
377 // Helper build the unary, binary, and type conversion functions defined by the
378 // DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the code that uses this
379 // class.
380 //
381 // Implementations of the math functions must be polymorphic over numeric types,
382 // internally performing necessary casts. If the function application makes no
383 // sense, then the only recourse is to assert and return nullptr. This can be
384 // extended later if it becomes possible to fail construction of the region. The
385 // invariant should be enforced at a higher level.
386 //
387 // TODO: These helpers are currently type polymorphic over the class of integer
388 // and floating point types, but they will not internally cast within bit
389 // widths of a class (mixed precision such as i8->i32) or across classes
390 // (i.e. mixed float and integer). Many such combinations are ambiguous or need
391 // to be handled with care and work is being considered to extend the op
392 // language to make such cases explicit. In the mean-time, violating this will
393 // fail verification, which is deemed acceptable.
394 //===----------------------------------------------------------------------===//
395 
396 namespace {
397 
398 class RegionBuilderHelper {
399 public:
400  RegionBuilderHelper(OpBuilder &builder, Block &block)
401  : builder(builder), block(block) {}
402 
403  // Build the unary functions defined by OpDSL.
404  Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
405  if (!isFloatingPoint(arg))
406  llvm_unreachable("unsupported non numeric type");
407  OpBuilder::InsertionGuard g(builder);
408  builder.setInsertionPointToEnd(&block);
409  switch (unaryFn) {
410  case UnaryFn::exp:
411  return builder.create<math::ExpOp>(arg.getLoc(), arg);
412  case UnaryFn::log:
413  return builder.create<math::LogOp>(arg.getLoc(), arg);
414  case UnaryFn::abs:
415  return builder.create<math::AbsFOp>(arg.getLoc(), arg);
416  case UnaryFn::ceil:
417  return builder.create<math::CeilOp>(arg.getLoc(), arg);
418  case UnaryFn::floor:
419  return builder.create<math::FloorOp>(arg.getLoc(), arg);
420  case UnaryFn::negf:
421  return builder.create<arith::NegFOp>(arg.getLoc(), arg);
422  case UnaryFn::reciprocal: {
423  Attribute oneAttr = builder.getOneAttr(arg.getType());
424  auto one = builder.create<arith::ConstantOp>(arg.getLoc(),
425  ::cast<TypedAttr>(oneAttr));
426  return builder.create<arith::DivFOp>(arg.getLoc(), one, arg);
427  }
428  case UnaryFn::round:
429  return builder.create<math::RoundOp>(arg.getLoc(), arg);
430  case UnaryFn::sqrt:
431  return builder.create<math::SqrtOp>(arg.getLoc(), arg);
432  case UnaryFn::rsqrt:
433  return builder.create<math::RsqrtOp>(arg.getLoc(), arg);
434  case UnaryFn::square:
435  return builder.create<arith::MulFOp>(arg.getLoc(), arg, arg);
436  case UnaryFn::tanh:
437  return builder.create<math::TanhOp>(arg.getLoc(), arg);
438  case UnaryFn::erf:
439  return builder.create<math::ErfOp>(arg.getLoc(), arg);
440  }
441  llvm_unreachable("unsupported unary function");
442  }
443 
444  // Build the binary functions defined by OpDSL.
445  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
446  bool allComplex = isComplex(arg0) && isComplex(arg1);
447  bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
448  bool allInteger = isInteger(arg0) && isInteger(arg1);
449  bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
450  arg1.getType().getIntOrFloatBitWidth() == 1;
451  if (!allComplex && !allFloatingPoint && !allInteger)
452  llvm_unreachable("unsupported non numeric type");
453  OpBuilder::InsertionGuard g(builder);
454  builder.setInsertionPointToEnd(&block);
455  switch (binaryFn) {
456  case BinaryFn::add:
457  if (allComplex)
458  return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
459  if (allFloatingPoint)
460  return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
461  if (allBool)
462  return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
463  return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
464  case BinaryFn::sub:
465  if (allComplex)
466  return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
467  if (allFloatingPoint)
468  return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
469  if (allBool)
470  llvm_unreachable("unsupported operation: sub with bools");
471  return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
472  case BinaryFn::mul:
473  if (allComplex)
474  return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
475  if (allFloatingPoint)
476  return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
477  if (allBool)
478  return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
479  return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
480  case BinaryFn::div:
481  if (allComplex)
482  return builder.create<complex::DivOp>(arg0.getLoc(), arg0, arg1);
483  if (allFloatingPoint)
484  return builder.create<arith::DivFOp>(arg0.getLoc(), arg0, arg1);
485  if (allBool)
486  llvm_unreachable("unsupported operation: div with bools");
487  return builder.create<arith::DivSIOp>(arg0.getLoc(), arg0, arg1);
488  case BinaryFn::div_unsigned:
489  if (!allInteger || allBool)
490  llvm_unreachable("unsupported operation: unsigned div not on uint");
491  return builder.create<arith::DivUIOp>(arg0.getLoc(), arg0, arg1);
492  case BinaryFn::max_signed:
493  assert(!allComplex);
494  if (allFloatingPoint)
495  return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
496  return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
497  case BinaryFn::min_signed:
498  assert(!allComplex);
499  if (allFloatingPoint)
500  return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
501  return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
502  case BinaryFn::max_unsigned:
503  assert(!allComplex);
504  if (allFloatingPoint)
505  return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
506  return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
507  case BinaryFn::min_unsigned:
508  assert(!allComplex);
509  if (allFloatingPoint)
510  return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
511  return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
512  case BinaryFn::powf:
513  assert(allFloatingPoint);
514  return builder.create<math::PowFOp>(arg0.getLoc(), arg0, arg1);
515  }
516  llvm_unreachable("unsupported binary function");
517  }
518 
519  // Build the ternary functions defined by OpDSL.
520  Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1,
521  Value arg2) {
522  bool headBool =
523  isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
524  bool tailFloatingPoint =
525  isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
526  bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
527  OpBuilder::InsertionGuard g(builder);
528  builder.setInsertionPointToEnd(&block);
529  switch (ternaryFn) {
530  case TernaryFn::select:
531  if (!headBool && !(tailFloatingPoint || tailInteger))
532  llvm_unreachable("unsupported non numeric type");
533  return builder.create<arith::SelectOp>(arg0.getLoc(), arg0, arg1, arg2);
534  }
535  llvm_unreachable("unsupported ternary function");
536  }
537 
538  // Build the type functions defined by OpDSL.
539  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
540  switch (typeFn) {
541  case TypeFn::cast_signed:
542  return cast(toType, operand, false);
543  case TypeFn::cast_unsigned:
544  return cast(toType, operand, true);
545  }
546  llvm_unreachable("unsupported type conversion function");
547  }
548 
549  void yieldOutputs(ValueRange values) {
550  OpBuilder::InsertionGuard g(builder);
551  builder.setInsertionPointToEnd(&block);
552  Location loc = builder.getUnknownLoc();
553  builder.create<YieldOp>(loc, values);
554  }
555 
556  Value constant(const std::string &value) {
557  OpBuilder::InsertionGuard g(builder);
558  builder.setInsertionPointToEnd(&block);
559  Location loc = builder.getUnknownLoc();
560  Attribute valueAttr = parseAttribute(value, builder.getContext());
561  return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
562  }
563 
564  Value index(int64_t dim) {
565  OpBuilder::InsertionGuard g(builder);
566  builder.setInsertionPointToEnd(&block);
567  return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
568  }
569 
570  Type getIntegerType(unsigned width) {
571  return IntegerType::get(builder.getContext(), width);
572  }
573 
574  Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
575  Type getFloat64Type() { return Float64Type::get(builder.getContext()); }
576 
577 private:
578  // Generates operations to cast the given operand to a specified type.
579  // If the cast cannot be performed, a warning will be issued and the
580  // operand returned as-is (which will presumably yield a verification
581  // issue downstream).
582  Value cast(Type toType, Value operand, bool isUnsignedCast) {
583  OpBuilder::InsertionGuard g(builder);
584  builder.setInsertionPointToEnd(&block);
585  auto loc = operand.getLoc();
586  return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
587  }
588 
589  bool isComplex(Value value) {
590  return llvm::isa<ComplexType>(value.getType());
591  }
592  bool isFloatingPoint(Value value) {
593  return llvm::isa<FloatType>(value.getType());
594  }
595  bool isInteger(Value value) {
596  return llvm::isa<IntegerType>(value.getType());
597  }
598 
599  OpBuilder &builder;
600  Block &block;
601 };
602 
603 } // namespace
604 
605 //===----------------------------------------------------------------------===//
606 // CopyOp
607 //===----------------------------------------------------------------------===//
608 
609 namespace {
610 
611 struct EraseSelfCopy : OpRewritePattern<CopyOp> {
613  LogicalResult matchAndRewrite(CopyOp copyOp,
614  PatternRewriter &rewriter) const override {
615  if (copyOp.getInputs() != copyOp.getOutputs())
616  return rewriter.notifyMatchFailure(copyOp, "not a self copy");
617  if (copyOp.hasPureBufferSemantics())
618  rewriter.eraseOp(copyOp);
619  else
620  rewriter.replaceOp(copyOp, copyOp.getInputs());
621 
622  return success();
623  }
624 };
625 
626 } // namespace
627 
628 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
629  MLIRContext *context) {
630  results.add<EraseSelfCopy>(context);
631 }
632 
633 //===----------------------------------------------------------------------===//
634 // FillOp
635 //===----------------------------------------------------------------------===//
636 
637 namespace {
638 
639 /// Fold linalg.fill -> tensor.expand/collapse_shape chain.
640 ///
641 /// For such op chains, we can create new linalg.fill ops with the result
642 /// type of the tensor.expand/collapse_shape op.
643 template <typename TensorReshapeOp>
644 struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
646  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
647  PatternRewriter &rewriter) const override {
648  auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
649  if (!oldFill)
650  return failure();
651 
652  Location loc = oldFill.getLoc();
653  TensorReshapeOp newInit;
654  if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
655 
656  newInit = rewriter.create<TensorReshapeOp>(
657  loc, reshapeOp.getResultType(), oldFill.output(),
658  reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
659  reshapeOp.getStaticOutputShape());
660  } else {
661  newInit = rewriter.create<TensorReshapeOp>(loc, reshapeOp.getResultType(),
662  oldFill.output(),
663  reshapeOp.getReassociation());
664  }
665  rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
666  ValueRange{newInit});
667  return success();
668  }
669 };
670 
671 /// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
672 /// filling value are the same.
673 struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
675 
676  LogicalResult matchAndRewrite(tensor::PadOp padOp,
677  PatternRewriter &rewriter) const override {
678  auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
679  if (!fillOp)
680  return failure();
681 
682  // We can only fold if the padding value is the same as the original
683  // filling value.
684  Value padValue = padOp.getConstantPaddingValue();
685  if (!padValue || fillOp.value() != padValue)
686  return failure();
687 
688  ReifiedRankedShapedTypeDims reifiedShape;
689  if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
690  return rewriter.notifyMatchFailure(
691  padOp, "failed to reify tensor.pad op result shape");
692 
693  auto emptyTensor = rewriter.create<tensor::EmptyOp>(
694  padOp.getLoc(), reifiedShape.front(),
695  padOp.getResultType().getElementType());
696  Value replacement =
697  rewriter
698  .create<FillOp>(fillOp.getLoc(), ValueRange{padValue},
699  ValueRange{emptyTensor})
700  .getResult(0);
701  if (replacement.getType() != padOp.getResultType()) {
702  replacement = rewriter.create<tensor::CastOp>(
703  fillOp.getLoc(), padOp.getResultType(), replacement);
704  }
705  rewriter.replaceOp(padOp, replacement);
706  return success();
707  }
708 };
709 
710 /// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
711 /// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
712 /// filling value are the same.
713 struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
715 
716  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
717  PatternRewriter &rewriter) const override {
718  auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
719  if (!srcPadOp)
720  return failure();
721 
722  if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
723  return failure();
724 
725  // Walk back the tensor.insert_slice chain and find the first destination
726  // value at the start of the chain.
727  Value firstDest = insertOp.getDest();
728  while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
729  if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
730  return failure();
731 
732  // Make sure the range of values accessed are disjoint. Without this, we
733  // cannot fold tensor.pad away.
734  bool disjoint = false;
735  for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
736  // If the dimension has dynamic offset/size, we cannot guarantee
737  // disjoint. So just skip it.
738  if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
739  insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
740  prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
741  continue;
742 
743  // Get the range start and end, inclusively for both.
744  int64_t prevStart = prevOp.getStaticOffset(i);
745  int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
746  prevOp.getStaticStride(i);
747  int64_t nextStart = insertOp.getStaticOffset(i);
748  int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
749  insertOp.getStaticStride(i);
750  if (prevEnd < nextStart || nextEnd < prevStart) {
751  disjoint = true;
752  break;
753  }
754  }
755 
756  if (!disjoint)
757  break;
758  firstDest = prevOp.getDest();
759  }
760 
761  // Check whether the first destination is a fill op. For overlapped cases,
762  // this also cannot be true.
763  auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
764  if (!dstFillOp)
765  return failure();
766 
767  // We can only fold if the padding value is the same as the original
768  // filling value.
769  Value padValue = srcPadOp.getConstantPaddingValue();
770  if (!padValue || dstFillOp.value() != padValue)
771  return failure();
772 
773  SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
774  SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();
775 
776  Location loc = insertOp.getLoc();
777  MLIRContext *context = getContext();
778 
779  AffineExpr sym0, sym1;
780  bindSymbols(context, sym0, sym1);
781  auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);
782 
783  // Calculate the new offsets for the insert. It should be the old offsets
784  // plus low padding sizes.
785  SmallVector<OpFoldResult, 4> newOffsets;
786  for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
787  newOffsets.push_back(affine::makeComposedFoldedAffineApply(
788  rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
789  }
790 
791  RankedTensorType srcPadType = srcPadOp.getSourceType();
793  for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
794  if (srcPadType.isDynamicDim(i)) {
795  newSizes.push_back(
796  rewriter.create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
797  .getResult());
798  } else {
799  newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
800  }
801  }
802 
803  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
804  insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
805  newSizes, insertOp.getMixedStrides());
806  return success();
807  }
808 };
809 
810 /// Fold tensor.extract(linalg.fill(<input>)) into <input>
811 struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
812 public:
814 
815  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
816  PatternRewriter &rewriter) const override {
817  // See if tensor input of tensor.extract op is the result of a linalg.fill
818  // op.
819  auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
820  if (!fillOp)
821  return failure();
822 
823  // Get scalar input operand of linalg.fill op.
824  Value extractedScalar = fillOp.getInputs()[0];
825 
826  // Replace tensor.extract op with scalar value used to fill the tensor.
827  rewriter.replaceOp(extractOp, extractedScalar);
828  return success();
829  }
830 };
831 
832 /// Folds pack(fill) into a single fill op if
833 /// 1. The pack op does not have padding value, or
834 /// 2. The filled value and padding value are the same.
835 static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
836  tensor::PackOp packOp) {
837  auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
838  if (!fillOp)
839  return failure();
840 
841  if (auto paddingValue = packOp.getPaddingValue())
842  if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
843  return failure();
844 
845  Value packOpDest = packOp.getDest();
846  if (!packOpDest.hasOneUse())
847  return failure();
848 
849  return rewriter.create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(),
850  packOp.getDest());
851 }
852 
853 /// Wrapper pattern that applies foldFillPackIntoFillOp method.
854 struct FoldFillWithPack : public OpRewritePattern<tensor::PackOp> {
855 public:
856  FoldFillWithPack(MLIRContext *context)
857  : OpRewritePattern<tensor::PackOp>(context) {}
858 
859  LogicalResult matchAndRewrite(tensor::PackOp packOp,
860  PatternRewriter &rewriter) const override {
861  auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
862  if (failed(fillOp))
863  return failure();
864  rewriter.replaceOp(packOp, fillOp.value().result());
865  return success();
866  }
867 };
868 
/// Fold fill with copy.
struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
  // NOTE(review): the inherited-constructor `using` declaration appears to
  // have been dropped by the documentation extraction at this point.

  LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    // copy(fill(x)) -> fill of the copy's destination: copying a filled
    // tensor is the same as filling the destination directly.
    if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
                                          fillOp.getInputs(),
                                          copyOp.getOutputs());
      return success();
    }
    // If the copy's destination comes from a fill, the fill is fully
    // overwritten by this copy: copy straight into the fill's own outputs.
    if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
                                                  fillOp.getOutputs());
      return success();
    }
    return failure();
  }
};
889 
/// Fold fill with transpose.
struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
  // NOTE(review): the inherited-constructor `using` declaration appears to
  // have been dropped by the documentation extraction at this point.

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // transpose(fill(x)) -> fill(x): transposing a constant-filled tensor
    // yields the same filled tensor, so fill the transpose's init directly.
    if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(
          transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
          transposeOp.getDpsInitOperand(0)->get());
      return success();
    }
    return failure();
  }
};
905 
/// Fold a concat with all elements being fills of the same value
/// into a fill of the concat result shape.
struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
  // NOTE(review): the inherited-constructor `using` declaration appears to
  // have been dropped by the documentation extraction at this point.

  LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    auto concatOperands = concatOp.getInputs();
    if (concatOperands.empty()) {
      return failure();
    }

    auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
    if (!firstFillOp) {
      return failure();
    }
    // Prefetch the fill value.
    OpFoldResult firstFillVal =
        getAsOpFoldResult(firstFillOp.getDpsInputOperand(0)->get());
    // Collect all the outs values for the fill operations.
    SmallVector<Value> allOuts;
    allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());

    // A "compatible" operand is one defined by a fill whose value folds to
    // the same OpFoldResult as the first fill's value. Note the side effect:
    // matching fills also have their init operand appended to `allOuts`.
    auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
      auto fillOp = v.getDefiningOp<linalg::FillOp>();
      if (!fillOp) {
        return false;
      }

      OpFoldResult fillVal =
          getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get());
      if (fillVal != firstFillVal)
        return false;

      allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
      return true;
    };
    if (!llvm::all_of(concatOperands.drop_front(),
                      isDefinedByCompatibleFillOp)) {
      return rewriter.notifyMatchFailure(
          concatOp, "not all operands are defined by a compatible fill op");
    }

    // Concat the fill inits along the same dimension, then fill the result
    // once with the common value.
    Value outsConcat = rewriter.create<tensor::ConcatOp>(
        concatOp.getLoc(), concatOp.getDim(), allOuts);
    rewriter.replaceOpWithNewOp<linalg::FillOp>(
        concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
    return success();
  }
};
956 
957 } // namespace
958 
void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  // Register the fill-folding patterns defined above (plus the reshape/pad
  // variants declared earlier in this file).
  results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
              FoldFillWithPack, FoldFillWithPad,
              FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
              FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
              FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
}
967 
968 //===----------------------------------------------------------------------===//
969 // GenericOp
970 //===----------------------------------------------------------------------===//
971 
972 static void buildGenericRegion(
973  OpBuilder &builder, Location loc, Region &region, ValueRange inputs,
974  ValueRange outputs,
975  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
976  SmallVector<Type, 4> blockArgTypes;
977  SmallVector<Location, 4> blockArgLocs;
978  for (ValueRange container : {inputs, outputs}) {
979  for (Value v : container) {
980  Type t = v.getType();
981  blockArgTypes.push_back(
982  isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
983  blockArgLocs.push_back(v.getLoc());
984  }
985  }
986 
987  OpBuilder::InsertionGuard guard(builder);
988  Block *bodyBlock =
989  builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
990  bodyBuild(builder, loc, bodyBlock->getArguments());
991 }
992 
993 void GenericOp::getAsmBlockArgumentNames(Region &region,
994  OpAsmSetValueNameFn setNameFn) {
995  for (Value v : getRegionInputArgs())
996  setNameFn(v, "in");
997  for (Value v : getRegionOutputArgs())
998  setNameFn(v, "out");
999 }
1000 
void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
    ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  // Delegate to the generated builder, then attach the extra attributes and,
  // when a body callback is provided, materialize the region.
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall);
  result.addAttributes(attributes);
  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, outputs, bodyBuild);
}
1014 
void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  // Wrap the raw affine maps / iterator enums / strings into attributes and
  // forward to the attribute-based builder. Empty doc/libraryCall strings
  // become null attributes so they are elided from the op.
  build(builder, result, resultTensorTypes, inputs, outputs,
        builder.getAffineMapArrayAttr(indexingMaps),
        builder.getArrayAttr(llvm::to_vector(llvm::map_range(
            iteratorTypes,
            [&](utils::IteratorType iter) -> mlir::Attribute {
              return IteratorTypeAttr::get(builder.getContext(), iter);
            }))),
        doc.empty() ? StringAttr() : builder.getStringAttr(doc),
        libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
        bodyBuild, attributes);
}
1033 
void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  // Convenience overload for the no-result (buffer) form: no result types.
  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
}
1044 
void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  // Convenience overload without doc / library_call strings.
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}
1055 
void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  // Convenience overload with result types but without doc / library_call.
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}
1067 
void GenericOp::print(OpAsmPrinter &p) {
  p << " ";

  // Print extra attributes.
  // The linalg trait attributes are printed up front as a dictionary; all
  // remaining attributes go into the trailing `attrs = {...}` clause below.
  auto genericAttrNames = linalgTraitAttrNames();

  llvm::StringSet<> genericAttrNamesSet;
  genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
  SmallVector<NamedAttribute, 8> genericAttrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
      // Convert IteratorType enums into the string representation. This is
      // needed, because tests still use the old format when 'iterator_types'
      // attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
      SmallVector<Attribute> iteratorTypeNames =
          llvm::to_vector(llvm::map_range(
              iteratorTypes, [&](utils::IteratorType t) -> Attribute {
                return StringAttr::get(getContext(), stringifyIteratorType(t));
              }));

      genericAttrs.emplace_back(
          getIteratorTypesAttrName(),
          ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
      genericAttrs.push_back(attr);
    }
  }
  if (!genericAttrs.empty()) {
    auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
    p << genericDictAttr;
  }

  // Printing is shared with named ops, except for the region and attributes
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());

  // operandSegmentSizes is printed implicitly by the structured parts above,
  // so it must be elided from the trailing dictionary as well.
  genericAttrNames.push_back("operandSegmentSizes");
  genericAttrNamesSet.insert(genericAttrNames.back());

  // Detect whether any attribute falls outside the trait set; only then is
  // the `attrs = {...}` clause emitted.
  bool hasExtraAttrs = false;
  for (NamedAttribute n : (*this)->getAttrs()) {
    if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
      break;
  }
  if (hasExtraAttrs) {
    p << " attrs = ";
    p.printOptionalAttrDict((*this)->getAttrs(),
                            /*elidedAttrs=*/genericAttrNames);
  }

  // Print region.
  if (!getRegion().empty()) {
    p << ' ';
    p.printRegion(getRegion());
  }

  // Print results.
  printNamedStructuredOpResults(p, getResultTensors().getTypes());
}
1130 
ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
  DictionaryAttr dictAttr;
  // Parse the core linalg traits that must check into a dictAttr.
  // The name is unimportant as we will overwrite result.attributes.
  // The core linalg traits must contain the information necessary to pass the
  // verifier.
  llvm::SMLoc attributeLocation = parser.getCurrentLocation();
  if (parser.parseAttribute(dictAttr, "_", result.attributes))
    return failure();
  // Splice the dictionary's entries directly into the op's attribute list
  // (replacing the placeholder "_" entry added above).
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert array of string into an array of IteratorType enums. This is
  // needed, because tests still use the old format when 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(attributeLocation)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = utils::symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(parser.getCurrentLocation())
             << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

  // Parsing is shared with named ops, except for the region.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Optional attributes may be added.
  if (succeeded(parser.parseOptionalKeyword("attrs")))
    if (failed(parser.parseEqual()) ||
        failed(parser.parseOptionalAttrDict(result.attributes)))
      return failure();

  // The region's block arguments are resolved against the region itself (no
  // external argument list is supplied here).
  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parser.parseRegion(*region, {}))
    return failure();
  result.addRegion(std::move(region));

  // Generic ops may specify that a subset of its outputs are tensors. Such
  // outputs are specified in the result type.
  // TODO: may need to move output parsing before region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  return success();
}
1196 
1199  &effects,
1200  LinalgOp linalgOp) {
1201  for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
1202  if (!llvm::isa<MemRefType>(operand.getType()))
1203  continue;
1204  effects.emplace_back(
1205  MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0,
1206  /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get());
1207  }
1208 
1209  for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
1210  if (!llvm::isa<MemRefType>(operand.get().getType()))
1211  continue;
1212  if (linalgOp.payloadUsesValueFromOperand(&operand)) {
1213  effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
1214  /*effectOnFullRegion=*/true,
1216  }
1217  effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
1218  /*effectOnFullRegion=*/true,
1220  }
1221 }
1222 
void GenericOp::getEffects(
    // NOTE(review): the effects-vector parameter type line appears to have
    // been dropped by the documentation extraction here.
        &effects) {
  // Delegate to the shared linalg effect computation.
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
1228 
1230 getGenericSpeculatabilityImpl(LinalgOp linalgOp) {
1231  // Operands with value semantics are speculatable, while operands with memory
1232  // semantics are not.
1233  if (!linalgOp.hasPureTensorSemantics())
1235  // The body of the op can still have speculation in its region.
1237 }
1238 
Speculation::Speculatability GenericOp::getSpeculatability() {
  // Delegate to the shared linalg speculatability rule.
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
1242 
// No op-specific checks here; presumably the shared structured-op verifier
// covers the invariants — TODO(review): confirm against the interface code.
LogicalResult GenericOp::verify() { return success(); }
1244 
1245 namespace {
1246 
/// Remove any linalg operation (on tensors) that are just copying
/// the values from inputs to the results. Requirements are
/// 1) All iterator types are parallel
/// 2) The body contains just a yield operation with the yielded values being
/// the arguments corresponding to the operands.
template <typename OpTy>
struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
  // NOTE(review): the inherited-constructor `using` declaration appears to
  // have been dropped by the documentation extraction at this point.

  LogicalResult matchAndRewrite(OpTy linalgOp,
                                PatternRewriter &rewriter) const override {
    // All indexing maps must be equal. It follows that they are permutations.
    if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
      return failure();

    // Check that the body of the linalg operation is just a linalg.yield
    // operation.
    Block &body = linalgOp->getRegion(0).front();
    if (!llvm::hasSingleElement(body))
      return failure();
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return failure();

    // In the buffer case, we need to check exact buffer equality.
    if (linalgOp.hasPureBufferSemantics()) {
      // Copying a buffer onto itself is a no-op: erase the operation.
      if (linalgOp.getNumDpsInputs() == 1 && linalgOp.getNumDpsInits() == 1 &&
          linalgOp.getDpsInputOperand(0)->get() ==
              linalgOp.getDpsInitOperand(0)->get()) {
        rewriter.eraseOp(linalgOp);
        return success();
      }
      return failure();
    }

    // Mixed semantics is not supported yet.
    if (!linalgOp.hasPureTensorSemantics())
      return failure();

    // Get the argument number of the returned values. That is the operand
    // number to use for replacing uses of this operation.
    SmallVector<Value> returnedArgs;
    for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
      // Every yielded value must be a block argument of this body; otherwise
      // the op is not a pure pass-through.
      auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
      if (!yieldArg || yieldArg.getOwner() != &body)
        return failure();
      unsigned argumentNumber = yieldArg.getArgNumber();
      Value returnedArg = linalgOp->getOperand(argumentNumber);
      Type resultType = linalgOp->getResult(yieldVal.index()).getType();
      // The input can have a different type than the result, e.g. a dynamic
      // input dimension can be turned into a static output dimension.
      Type returnType = returnedArg.getType();
      if (returnType != resultType) {
        // Distinguish between sparse conversion or dense tensor casting.
        // TODO: unify the two ops?
        // NOTE(review): the second operand of this `||` condition appears to
        // have been dropped by the documentation extraction here.
        if (sparse_tensor::getSparseTensorEncoding(returnType) ||
          returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        else {
          if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
                                                 resultType))
            return failure();
          returnedArg = rewriter.create<tensor::CastOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        }
      }
      returnedArgs.push_back(returnedArg);
    }

    if (returnedArgs.size() != linalgOp->getNumResults())
      return failure();
    rewriter.replaceOp(linalgOp, returnedArgs);
    return success();
  }
};
1323 
1324 } // namespace
1325 
void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Only the identity-erasure pattern applies at canonicalization time.
  results.add<EraseIdentityLinalgOp<GenericOp>>(context);
}
1330 
LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  // Fold away memref.cast producers of operands where possible.
  return memref::foldMemRefCast(*this);
}
1334 
1335 //===----------------------------------------------------------------------===//
1336 // MapOp
1337 //===----------------------------------------------------------------------===//
1338 
1339 static ParseResult parseDstStyleOp(
1340  OpAsmParser &parser, OperationState &result,
1341  function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
1342  nullptr) {
1343  // Parse `ins` and `outs`.
1344  SmallVector<Type, 4> inputTypes, outputTypes;
1345  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
1346  /*addOperandSegmentSizes=*/false))
1347  return failure();
1348 
1349  // Add result types.
1350  for (Type outputType : outputTypes) {
1351  if (llvm::isa<RankedTensorType>(outputType))
1352  result.addTypes(outputType);
1353  }
1354 
1355  // Parse required attributes.
1356  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
1357  return failure();
1358 
1359  // Parse optional attributes.
1360  if (parser.parseOptionalAttrDict(result.attributes))
1361  return failure();
1362  return success();
1363 }
1364 
1365 void MapOp::getAsmBlockArgumentNames(Region &region,
1366  OpAsmSetValueNameFn setNameFn) {
1367  for (Value v : getRegionInputArgs())
1368  setNameFn(v, "in");
1369 }
1370 
1371 void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1372  if (!getResults().empty())
1373  setNameFn(getResults().front(), "mapped");
1374 }
1375 
void MapOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, init);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  // The mapper region only receives the input element values; the init does
  // not contribute a block argument (outputs passed empty below).
  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, /*outputs=*/{}, bodyBuild);
}
1392 
1393 static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
1394  const OperationName &payloadOpName,
1395  const NamedAttrList &payloadOpAttrs,
1396  ArrayRef<Value> operands,
1397  bool initFirst = false) {
1398  OpBuilder b(parser.getContext());
1399  Region *body = result.addRegion();
1400  Block &block = body->emplaceBlock();
1401  b.setInsertionPointToStart(&block);
1402  SmallVector<Value> bbArgs;
1403  for (auto &operand : operands) {
1404  block.addArgument(
1405  llvm::cast<ShapedType>(operand.getType()).getElementType(),
1406  b.getUnknownLoc());
1407  }
1408  SmallVector<Value> payloadOpOperands;
1409  // If initFirst flag is enabled, we consider init as the first position of
1410  // payload operands.
1411  if (initFirst) {
1412  payloadOpOperands.push_back(block.getArguments().back());
1413  for (const auto &arg : block.getArguments().drop_back())
1414  payloadOpOperands.push_back(arg);
1415  } else {
1416  payloadOpOperands = {block.getArguments().begin(),
1417  block.getArguments().end()};
1418  }
1419 
1420  Operation *payloadOp = b.create(
1421  result.location, b.getStringAttr(payloadOpName.getStringRef()),
1422  payloadOpOperands,
1423  TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
1424  .getElementType()},
1425  payloadOpAttrs);
1426  b.create<YieldOp>(result.location, payloadOp->getResults());
1427 }
1428 
ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
  std::optional<OperationName> payloadOpName;
  NamedAttrList payloadOpAttrs;
  // Optional short form `{ op.name {attrs} }` naming a single payload op that
  // will be materialized as the mapper body.
  if (succeeded(parser.parseOptionalLBrace())) {
    FailureOr<OperationName> operationName = parser.parseCustomOperationName();
    if (failed(operationName))
      return failure();
    if (parser.parseOptionalAttrDict(payloadOpAttrs))
      return failure();
    payloadOpName = operationName.value();
    if (parser.parseRBrace())
      return failure();
  }

  if (parseDstStyleOp(parser, result))
    return failure();

  if (payloadOpName.has_value()) {
    // Build the body from the named payload op; the trailing init operand is
    // dropped since the mapper only takes the inputs.
    if (!result.operands.empty())
      addBodyWithPayloadOp(parser, result, payloadOpName.value(),
                           payloadOpAttrs,
                           ArrayRef(result.operands).drop_back());
    else
      result.addRegion();
  } else {
    // Long form: explicit parenthesized block arguments plus a region.
    // NOTE(review): the declaration of `regionArgs` (presumably a
    // SmallVector<OpAsmParser::Argument>) appears to have been dropped by the
    // documentation extraction here.
    if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
                                 /*allowType=*/true, /*allowAttrs=*/true)) {
      return failure();
    }
    Region *body = result.addRegion();
    if (parser.parseRegion(*body, regionArgs))
      return failure();
  }
  return success();
}
1465 
1466 // Retrieve the operation from the body, if it is the only one (except
1467 // yield) and if it gets the same amount of arguments as the body does.
1468 // If initFirst flag is enabled, we check that init takes the first position in
1469 // operands of payload.
1470 static Operation *findPayloadOp(Block *body, bool initFirst = false) {
1471  if (body->getOperations().size() != 2)
1472  return nullptr;
1473  Operation &payload = body->getOperations().front();
1474  assert(isa<YieldOp>(body->getOperations().back()));
1475 
1476  if (payload.getNumOperands() == 0 ||
1477  payload.getNumOperands() != body->getNumArguments())
1478  return nullptr;
1479  if (initFirst) {
1480  // check init
1481  if (payload.getOperands().back() != body->getArgument(0))
1482  return nullptr;
1483  // check rest
1484  for (const auto &[operand, bbArg] :
1485  llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
1486  if (bbArg != operand)
1487  return nullptr;
1488  }
1489  } else {
1490  for (const auto &[operand, bbArg] :
1491  llvm::zip(payload.getOperands(), body->getArguments())) {
1492  if (bbArg != operand)
1493  return nullptr;
1494  }
1495  }
1496  return &payload;
1497 }
1498 
1499 void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
1500  SmallVector<StringRef> elidedAttrs;
1501  p << " { " << payloadOp->getName().getStringRef();
1502  for (const auto &attr : payloadOp->getAttrs()) {
1503  if (auto fastAttr = dyn_cast<arith::FastMathFlagsAttr>(attr.getValue())) {
1504  if (fastAttr.getValue() == arith::FastMathFlags::none) {
1505  elidedAttrs.push_back(attr.getName());
1506  }
1507  }
1508  if (auto denormAttr = dyn_cast<arith::DenormalModeAttr>(attr.getValue())) {
1509  if (denormAttr.getValue() == arith::DenormalMode::ieee) {
1510  elidedAttrs.push_back(attr.getName());
1511  }
1512  }
1513  }
1514  p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
1515  p << " }";
1516 }
1517 
void MapOp::print(OpAsmPrinter &p) {
  Block *mapper = getBody();
  // Prefer the compact `{ payload.op }` form when the mapper is a single op
  // consuming exactly the block arguments.
  Operation *payloadOp = findPayloadOp(mapper);
  if (payloadOp) {
    printShortForm(p, payloadOp);
  }

  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
  p.printOptionalAttrDict((*this)->getAttrs());

  if (!payloadOp) {
    // Print region if the payload op was not detected.
    p.increaseIndent();
    p.printNewline();
    p << "(";
    llvm::interleaveComma(mapper->getArguments(), p,
                          [&](auto arg) { p.printRegionArgument(arg); });
    p << ") ";

    p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
    p.decreaseIndent();
  }
}
1541 
// Verify map invariants: mapper arity matches the inputs, mapper argument
// types match the input element types, and every input shape equals the init
// shape.
LogicalResult MapOp::verify() {
  auto *bodyBlock = getBody();
  auto blockArgs = bodyBlock->getArguments();

  // Checks if the number of `inputs` match the arity of the `mapper` region.
  if (getInputs().size() != blockArgs.size())
    return emitOpError() << "expects number of operands to match the arity of "
                            "mapper, but got: "
                         << getInputs().size() << " and " << blockArgs.size();

  // The parameters of mapper should all match the element type of inputs.
  for (const auto &[bbArgType, inputArg] :
       llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
    auto inputElemType =
        llvm::cast<ShapedType>(inputArg.getType()).getElementType();
    if (bbArgType != inputElemType) {
      return emitOpError() << "expected element type of input " << inputElemType
                           << " to match bbArg type " << bbArgType;
    }
  }

  // The shape of each input must match the shape of the output.
  auto outputShape = getInit().getType().getShape();
  for (Type inputArgType : TypeRange{getInputs()}) {
    auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
    if (inputElemShape != outputShape) {
      return emitOpError() << "expected shape of input (" << inputElemShape
                           << ") to match shape of output (" << outputShape
                           << ")";
    }
  }

  return success();
}
1576 
1577 SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
1578  int64_t rank = getInit().getType().getRank();
1579  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1580 }
1581 
ArrayAttr MapOp::getIndexingMaps() {
  Builder builder(getContext());
  int64_t rank = getInit().getType().getRank();
  int64_t numIndexingMaps = getOperands().size();
  // Every operand (inputs and init) uses the rank-`rank` identity map.
  // NOTE(review): the start of the return statement appears to have been
  // dropped by the documentation extraction here.
      numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
}
1589 
void MapOp::getEffects(
    // NOTE(review): the effects-vector parameter type line appears to have
    // been dropped by the documentation extraction here.
        &effects) {
  // Delegate to the shared linalg effect computation.
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
1595 
Speculation::Speculatability MapOp::getSpeculatability() {
  // Delegate to the shared linalg speculatability rule.
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
1599 
1600 //===----------------------------------------------------------------------===//
1601 // ReduceOp
1602 //===----------------------------------------------------------------------===//
1603 
1604 void ReduceOp::getAsmBlockArgumentNames(Region &region,
1605  OpAsmSetValueNameFn setNameFn) {
1606  for (Value v : getRegionInputArgs())
1607  setNameFn(v, "in");
1608  for (Value v : getRegionOutputArgs())
1609  setNameFn(v, "init");
1610 }
1611 
1612 void ReduceOp::getAsmResultNames(
1613  function_ref<void(Value, StringRef)> setNameFn) {
1614  if (!getResults().empty())
1615  setNameFn(getResults().front(), "reduced");
1616 }
1617 
void ReduceOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange inits, ArrayRef<int64_t> dimensions,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  // Delegate to the generated builder, then attach the extra attributes.
  build(builder, result, TypeRange{}, inputs, inits, dimensions);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  for (Value init : inits) {
    Type initType = init.getType();
    if (llvm::isa<RankedTensorType>(initType))
      result.addTypes(initType);
  }

  // The combiner region takes both input and init element values.
  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, inits, bodyBuild);
}
1637 
1638 SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
1639  int64_t inputRank =
1640  llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1641  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
1642  utils::IteratorType::parallel);
1643  for (int64_t reductionDim : getDimensions())
1644  iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1645  return iteratorTypes;
1646 }
1647 
ArrayAttr ReduceOp::getIndexingMaps() {
  int64_t inputRank =
      llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
  // Inputs use identity maps over the full input rank.
  SmallVector<AffineMap> affineMaps(
      getNumDpsInputs(),
  // NOTE(review): the identity-map argument line appears to have been
  // dropped by the documentation extraction here.
  // Results use the identity map with the reduced dimensions removed.
  AffineMap resultMap =
  // NOTE(review): the start of this initializer appears to have been dropped
  // by the documentation extraction here.
          .dropResults(getDimensions());
  for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
    affineMaps.push_back(resultMap);
  return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
}
1661 
void ReduceOp::getEffects(
    // NOTE(review): the effects-vector parameter type line appears to have
    // been dropped by the documentation extraction here.
        &effects) {
  // Delegate to the shared linalg effect computation.
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
1667 
Speculation::Speculatability ReduceOp::getSpeculatability() {
  // Delegate to the shared linalg speculatability rule.
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
1671 
1672 static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
1673  NamedAttrList &attributes,
1674  StringRef attributeName) {
1675  if (parser.parseKeyword(attributeName) || parser.parseEqual())
1676  return failure();
1677 
1678  attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{}));
1679  return success();
1680 }
1681 
ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
  std::optional<OperationName> payloadOpName;
  NamedAttrList payloadOpAttrs;
  // Optional short form `{ op.name {attrs} }` naming the combiner op.
  if (succeeded(parser.parseOptionalLBrace())) {
    FailureOr<OperationName> operationName = parser.parseCustomOperationName();
    if (failed(operationName))
      return failure();
    if (parser.parseOptionalAttrDict(payloadOpAttrs))
      return failure();
    payloadOpName = operationName.value();
    if (parser.parseRBrace())
      return failure();
  }

  // Parse `ins`/`outs` plus the required `dimensions` attribute.
  if (parseDstStyleOp(
          parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
            return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
          }))
    return failure();

  if (payloadOpName.has_value()) {
    // Reconstruct the combiner body from the named payload op; the init is
    // passed as the first payload operand.
    addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
                         ArrayRef(result.operands), /*initFirst=*/true);
  } else {
    // Long form: explicit parenthesized block arguments plus a region.
    // NOTE(review): the declaration of `regionArgs` (presumably a
    // SmallVector<OpAsmParser::Argument>) appears to have been dropped by the
    // documentation extraction here.
    if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
                                 /*allowType=*/true, /*allowAttrs=*/true)) {
      return failure();
    }

    Region *body = result.addRegion();
    if (parser.parseRegion(*body, regionArgs))
      return failure();
  }

  return success();
}
1719 
1720 static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
1721  ArrayRef<int64_t> attributeValue) {
1722  p << ' ' << attributeName << " = [" << attributeValue << "] ";
1723 }
1724 
/// Prints the op: uses the compact short form when the body is a single
/// recognizable payload op, otherwise prints the explicit region.
void ReduceOp::print(OpAsmPrinter &p) {
  Block *mapper = getBody();
  // initFirst=true: block arguments carry inits before inputs.
  Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
  if (payloadOp) {
    printShortForm(p, payloadOp);
  }

  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
  // `dimensions` was printed explicitly above, elide it from the dict.
  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
  if (!payloadOp) {
    // Print region if the payload op was not detected.
    p.increaseIndent();
    p.printNewline();
    p << "(";
    llvm::interleaveComma(mapper->getArguments(), p,
                          [&](auto arg) { p.printRegionArgument(arg); });
    p << ") ";

    p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
    p.decreaseIndent();
  }
}
1748 
/// Verifies input/init shape agreement, reduction-dimension validity, and that
/// the body's block arguments match the operand element types.
LogicalResult ReduceOp::verify() {
  ArrayRef<int64_t> dimensionsRef = getDimensions();

  // All inputs must share one shape; input 0 is the reference.
  for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
    if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
        llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
      return emitOpError() << "expects all inputs to have the same shapes. "
                              "Shape at input-index "
                           << i
                           << " is not equal to the shape at input-index 0.";
    }
  }
  // Likewise, all inits must share one shape; init 0 is the reference.
  for (int64_t i = 1; i < getNumDpsInits(); ++i) {
    if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
        llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
      return emitOpError() << "expects all outputs to have the same shapes. "
                              "Shape at output-index "
                           << i
                           << " is not equal to the shape at output-index 0.";
    }
  }
  auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
  auto initType = llvm::cast<ShapedType>(getInits()[0].getType());

  // Every reduction dimension must be a valid input dimension index.
  DenseSet<int64_t> dimensionsToReduce;
  for (int64_t dimension : dimensionsRef) {
    if (dimension < 0 || dimension >= inputType.getRank()) {
      return emitOpError()
             << "dimensions for reduction should be in the range [0, "
             << inputType.getRank() - 1 << "].";
    }
    dimensionsToReduce.insert(dimension);
  }

  auto inputDims = inputType.getShape();
  auto initDims = initType.getShape();

  // Input dimensions that will be left after the reduction.
  SmallVector<int64_t> reducedInputDims;
  for (const auto &en : llvm::enumerate(inputDims)) {
    if (!dimensionsToReduce.count(en.index()))
      reducedInputDims.push_back(en.value());
  }

  // The init rank must equal the number of surviving input dimensions.
  if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
    return emitOpError() << "number of dimensions after reduction "
                         << reducedInputDims.size()
                         << " doesn't match the init rank "
                         << initType.getRank();
  }

  // ...and the surviving sizes must match the init sizes exactly.
  if (reducedInputDims != initDims)
    return emitOpError() << "init dimensions [" << initDims
                         << "] doesn't match input dimensions after reduction ["
                         << reducedInputDims << "]";

  Block *block = getBody();
  // One block argument per operand (inputs followed by inits).
  if (block->getNumArguments() != this->getNumOperands())
    return emitOpError()
           << "mismatching number of operands and block arguments";

  // Check that the first block arguments match the element type of the inputs.
  for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
    Type inputElementType =
        llvm::cast<ShapedType>(input.getType()).getElementType();
    if (inputElementType != bbArg.getType())
      return emitOpError()
             << "input element type " << inputElementType
             << " does not match corresponding block argument type "
             << bbArg.getType();
  }

  // Check that the last block arguments match the element type of the outputs.
  for (auto [output, bbArg] : llvm::zip(
           getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
    auto outputElementType =
        llvm::cast<ShapedType>(output.getType()).getElementType();
    if (outputElementType != bbArg.getType())
      return emitOpError()
             << "output element type " << outputElementType
             << " does not match corresponding block argument type "
             << bbArg.getType();
  }
  return success();
}
1834 
1835 //===----------------------------------------------------------------------===//
1836 // TransposeOp
1837 //===----------------------------------------------------------------------===//
1838 
/// Populates `region` with a body that simply yields the first input element,
/// i.e. the identity payload used by transpose/broadcast ops.
static void buildIdentityRegion(OpBuilder &builder, Location loc,
                                Region &region, ValueRange inputs,
                                ValueRange outputs) {
  buildGenericRegion(builder, loc, region, inputs, outputs,
                     [](OpBuilder &b, Location loc, ValueRange args) {
                       // Yield the first block argument unchanged.
                       if (!args.empty())
                         b.create<linalg::YieldOp>(loc, args[0]);
                     });
}
1848 
1849 void TransposeOp::build(::mlir::OpBuilder &builder,
1850  ::mlir::OperationState &result, Value input, Value init,
1851  DenseI64ArrayAttr permutation,
1852  ArrayRef<NamedAttribute> attributes) {
1853  result.addOperands(input);
1854  result.addOperands(init);
1855  result.addAttribute(getPermutationAttrName(result.name), permutation);
1856  result.addAttributes(attributes);
1857 
1858  // Add output types for `RankedTensorType` output arguments.
1859  Type initType = init.getType();
1860  if (llvm::isa<RankedTensorType>(initType))
1861  result.addTypes(initType);
1862 
1863  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
1864  init);
1865 }
1866 
/// Convenience overload: wraps the raw permutation in a DenseI64ArrayAttr and
/// delegates to the attribute-based builder.
void TransposeOp::build(::mlir::OpBuilder &builder,
                        ::mlir::OperationState &result, Value input, Value init,
                        ArrayRef<int64_t> permutation,
                        ArrayRef<NamedAttribute> attributes) {
  build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
        attributes);
}
1874 
/// Parses ins/outs plus the `permutation` attribute; the body region is not
/// written in the assembly and is rebuilt as the identity payload.
ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
  if (failed(parseDstStyleOp(
          parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
            return parseDenseI64ArrayAttr(parser, attributes, "permutation");
          })))
    return failure();

  // Reconstruct the implicit identity body.
  OpBuilder builder(parser.getContext());
  buildIdentityRegion(builder, result.location, *result.addRegion(),
                      /*inputs=*/result.operands,
                      /*outputs=*/{});
  return success();
}
1888 
1889 void TransposeOp::getAsmResultNames(
1890  function_ref<void(Value, StringRef)> setNameFn) {
1891  if (!getResults().empty())
1892  setNameFn(getResults().front(), "transposed");
1893 }
1894 
1896  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1897  printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
1898  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
1899 }
1900 
/// Verifies the permutation is valid, ranks agree, and each init dim matches
/// the permuted input dim.
LogicalResult TransposeOp::verify() {
  ArrayRef<int64_t> permutationRef = getPermutation();

  if (!isPermutationVector(permutationRef))
    return emitOpError("permutation is not valid");

  auto inputType = getInput().getType();
  auto initType = getInit().getType();

  int64_t rank = inputType.getRank();

  if (rank != initType.getRank())
    return emitOpError() << "input rank " << rank
                         << " does not match init rank " << initType.getRank();

  if (rank != static_cast<int64_t>(permutationRef.size()))
    return emitOpError() << "size of permutation " << permutationRef.size()
                         << " does not match the argument rank " << rank;

  auto inputDims = inputType.getShape();
  auto initDims = initType.getShape();

  // dim(init, i) must equal dim(input, permutation[i]).
  for (int64_t i = 0; i < rank; ++i) {
    int64_t inputDim = inputDims[permutationRef[i]];
    int64_t initDim = initDims[i];

    if (inputDim != initDim) {
      return emitOpError() << "dim(result, " << i << ") = " << initDim
                           << " doesn't match dim(input, permutation[" << i
                           << "]) = " << inputDim;
    }
  }

  return success();
}
1936 
1937 SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
1938  int64_t rank = getInit().getType().getRank();
1939  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1940 }
1941 
1942 ArrayAttr TransposeOp::getIndexingMaps() {
1943  Builder builder(getContext());
1944  int64_t rank = getInit().getType().getRank();
1945  return builder.getAffineMapArrayAttr(
1947  llvm::to_vector_of<unsigned>(getPermutation()), getContext())),
1948  builder.getMultiDimIdentityMap(rank)});
1949 }
1950 
1951 void TransposeOp::getEffects(
1953  &effects) {
1954  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1955 }
1956 
/// Speculatability is delegated to the shared LinalgOp implementation.
Speculation::Speculatability TransposeOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
1960 
1961 LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
1963  // Only the tensor type is supported.
1964  if (!isa<TensorType>(getInput().getType()))
1965  return failure();
1966 
1967  // Single dimension transpose.
1968  if (getPermutation().size() == 0) {
1969  result.push_back(getInput());
1970  return success();
1971  }
1972  // Identity permutation.
1973  if (isIdentityPermutation(getPermutation())) {
1974  result.push_back(getInput());
1975  return success();
1976  }
1977 
1978  return failure();
1979 }
1980 
1981 /// Fold transpose with transpose.
1982 struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
1984 
1985  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
1986  PatternRewriter &rewriter) const override {
1987  auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
1988  if (!defTransposeOp)
1989  return failure();
1990  ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation();
1991  ArrayRef<int64_t> perms = transposeOp.getPermutation();
1992  SmallVector<int64_t> foldedPerms;
1993  foldedPerms.reserve(perms.size());
1994  for (int64_t perm : perms)
1995  foldedPerms.push_back(defPerms[perm]);
1996 
1997  rewriter.replaceOpWithNewOp<TransposeOp>(
1998  transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
1999  foldedPerms);
2000  return success();
2001  }
2002 };
2003 
2004 /// This pattern canonicalize transpose by swapping the order of
2005 /// broadcast and transpose:
2006 /// transpose(broadcast(input)) -> broadcast(transpose(input))
2007 struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
2009 
2010  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2011  PatternRewriter &rewriter) const override {
2012  Value input = transposeOp.getInput();
2013  BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
2014  if (!input.hasOneUse() || !broadcastOp)
2015  return failure();
2016 
2017  ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
2018  ArrayRef<int64_t> perms = transposeOp.getPermutation();
2019 
2020  // Get new perms and new dimensions.
2021  SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
2022  SmallVector<int64_t> invertPerm = invertPermutationVector(perms);
2023  SmallVector<int64_t> resultDimensions;
2024  unsigned dimensionSize = dimensions.size();
2025  for (unsigned i = 0; i < dimensionSize; ++i)
2026  resultDimensions.push_back(invertPerm[dimensions[i]]);
2027 
2028  // Create transpose result.
2029  Value broadcastInput = broadcastOp.getInput();
2030  Location loc = transposeOp.getLoc();
2031  MLIRContext *ctx = transposeOp.getContext();
2033  auto broadcastInputTy =
2034  mlir::cast<RankedTensorType>(broadcastInput.getType());
2035  unsigned inputRank = broadcastInputTy.getRank();
2036  for (unsigned i = 0; i < inputRank; ++i) {
2037  if (broadcastInputTy.isDynamicDim(i)) {
2038  dims.push_back(rewriter.create<tensor::DimOp>(loc, broadcastInput, i)
2039  ->getResult(0));
2040  } else {
2041  dims.push_back(IntegerAttr::get(IndexType::get(ctx),
2042  broadcastInputTy.getDimSize(i)));
2043  }
2044  }
2045  SmallVector<OpFoldResult> transposeResultShapes =
2046  applyPermutation(dims, resultPerms);
2047  Value transposeInit = rewriter.create<tensor::EmptyOp>(
2048  transposeOp.getLoc(), transposeResultShapes,
2049  broadcastInputTy.getElementType());
2050 
2051  // Create broadcast(transpose(input)).
2052  Value transposeResult =
2053  rewriter
2054  .create<TransposeOp>(loc, broadcastOp.getInput(), transposeInit,
2055  resultPerms)
2056  ->getResult(0);
2057  rewriter.replaceOpWithNewOp<BroadcastOp>(
2058  transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
2059  return success();
2060  }
2061 };
2062 
2063 void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2064  MLIRContext *context) {
2066 }
2067 
2068 //===----------------------------------------------------------------------===//
2069 // BroadcastOp
2070 //===----------------------------------------------------------------------===//
2071 
/// Builds a broadcast from `input` into `init`; `dimensions` lists the init
/// dims that are added (not present in the input). A tensor init also yields
/// a result value.
void BroadcastOp::build(::mlir::OpBuilder &builder,
                        ::mlir::OperationState &result, Value input, Value init,
                        DenseI64ArrayAttr dimensions,
                        ArrayRef<NamedAttribute> attributes) {
  result.addOperands(input);
  result.addOperands(init);
  result.addAttribute(getDimensionsAttrName(result.name), dimensions);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  // The body is the identity payload (yield the input element).
  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
                      init);
}
2089 
/// Convenience overload: wraps the raw dimensions in a DenseI64ArrayAttr and
/// delegates to the attribute-based builder.
void BroadcastOp::build(::mlir::OpBuilder &builder,
                        ::mlir::OperationState &result, Value input, Value init,
                        ArrayRef<int64_t> dimensions,
                        ArrayRef<NamedAttribute> attributes) {
  build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
        attributes);
}
2097 
/// Parses ins/outs plus the `dimensions` attribute; the body region is not
/// written in the assembly and is rebuilt as the identity payload.
ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
  if (failed(parseDstStyleOp(
          parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
            return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
          })))
    return failure();

  // Reconstruct the implicit identity body.
  OpBuilder builder(parser.getContext());
  buildIdentityRegion(builder, result.location, *result.addRegion(),
                      /*inputs=*/result.operands,
                      /*outputs=*/{});
  return success();
}
2111 
/// Suggests the name "broadcasted" for the op's result, if any.
void BroadcastOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "broadcasted");
}
2117 
2119  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
2120  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
2121  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
2122 }
2123 
/// Verifies rank arithmetic (input rank + #added dims == init rank), that each
/// added dimension index is in range, and that non-added dims match sizes.
LogicalResult BroadcastOp::verify() {
  ArrayRef<int64_t> dimensionsRef = getDimensions();

  auto inputType = getInput().getType();
  auto initType = getInit().getType();

  int64_t inputRank = inputType.getRank();
  int64_t initRank = initType.getRank();

  auto inputShape = inputType.getShape();
  auto initShape = initType.getShape();

  if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
    return emitOpError() << "input rank plus added dimensions does not "
                            "match init rank. input rank: "
                         << inputRank
                         << ", dimensions size: " << dimensionsRef.size()
                         << ", init rank: " << initRank;

  // Each added dimension must be a valid init dimension index.
  for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
    if (dim < 0 || dim >= initRank)
      return emitOpError() << "dimension " << idx
                           << " is out of range. expected range: [0, "
                           << initRank - 1 << "], got: " << dim;
  }

  // Mapping from input dims to init dims.
  SmallVector<int64_t> dimMap;
  for (auto dim : llvm::seq<int64_t>(0, initRank)) {
    if (!llvm::is_contained(dimensionsRef, dim))
      dimMap.push_back(dim);
  }

  for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
    // This dimensions is mapped from the input. Init and input dims should
    // match.
    if (inputShape[inputDimIdx] != initShape[initDimIdx])
      return emitOpError() << "input dim " << inputDimIdx
                           << " should match init dim " << initDimIdx
                           << ". input: " << inputShape[inputDimIdx]
                           << ", init: " << initShape[initDimIdx];
  }

  return success();
}
2169 
/// Every loop of a broadcast is parallel; there is one loop per init dim.
SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
  int64_t rank = getInit().getType().getRank();
  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
}
2174 
2175 ArrayAttr BroadcastOp::getIndexingMaps() {
2176  Builder builder(getContext());
2177  int64_t rank = getInit().getType().getRank();
2178  return builder.getAffineMapArrayAttr(
2179  {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
2180  builder.getMultiDimIdentityMap(rank)});
2181 }
2182 
2183 void BroadcastOp::getEffects(
2185  &effects) {
2186  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
2187 }
2188 
/// Speculatability is delegated to the shared LinalgOp implementation.
Speculation::Speculatability BroadcastOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
2192 
/// A broadcast whose input and init shapes coincide is an identity and can be
/// erased.
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
}
2197 
2198 //===----------------------------------------------------------------------===//
2199 // YieldOp
2200 //===----------------------------------------------------------------------===//
2201 
2203  if (getNumOperands() > 0)
2204  p << ' ' << getOperands();
2205  p.printOptionalAttrDict((*this)->getAttrs());
2206  if (getNumOperands() > 0)
2207  p << " : " << getOperandTypes();
2208 }
2209 
2210 ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
2212  SmallVector<Type, 2> types;
2213  SMLoc loc = parser.getCurrentLocation();
2214  return failure(parser.parseOperandList(opInfo) ||
2215  parser.parseOptionalAttrDict(result.attributes) ||
2216  (!opInfo.empty() && parser.parseColonTypeList(types)) ||
2217  parser.resolveOperands(opInfo, types, loc, result.operands));
2218 }
2219 
// Check the operand number and types must match the element types of the
// LinalgOp interface's shaped operands.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
  // One yielded value per init/out operand of the enclosing op.
  if (op.getNumOperands() != linalgOp.getNumDpsInits())
    return op.emitOpError("expected number of yield values (")
           << op.getNumOperands()
           << ") to match the number of inits / outs operands of the enclosing "
           << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";

  for (OpOperand &opOperand : op->getOpOperands()) {
    OpOperand *outputOperand =
        linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
    // For shaped outputs compare against the element type; scalars compare
    // against the type itself.
    Type elementType = outputOperand->get().getType();
    if (isa<MemRefType, RankedTensorType>(elementType))
      elementType = getElementTypeOrSelf(outputOperand->get().getType());
    if (opOperand.get().getType() != elementType)
      return op.emitOpError("type of yield operand ")
             << (opOperand.getOperandNumber() + 1) << " ("
             << opOperand.get().getType() << ") doesn't match "
             << "the element type of the enclosing linalg.generic op ("
             << elementType << ")";
  }
  return success();
}
2244 
2245 LogicalResult linalg::YieldOp::verify() {
2246  auto *parentOp = (*this)->getParentOp();
2247  if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2248  return emitOpError("expected single non-empty parent region");
2249 
2250  if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2251  return verifyYield(*this, linalgOp);
2252 
2253  return emitOpError("expected parent op with LinalgOp interface");
2254 }
2255 
2256 //===----------------------------------------------------------------------===//
2257 // IndexOp
2258 //===----------------------------------------------------------------------===//
2259 
2260 LogicalResult IndexOp::verify() {
2261  auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2262  if (!linalgOp)
2263  return emitOpError("expected parent op with LinalgOp interface");
2264  if (linalgOp.getNumLoops() <= getDim())
2265  return emitOpError("expected dim (")
2266  << getDim() << ") to be lower than the number of loops ("
2267  << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
2268  return success();
2269 }
2270 
2271 /////// Operations corresponding to library calls defined with Tablegen ////////
2272 
2273 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2274 
2275 #define GET_OP_CLASSES
2276 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2277 
2278 #define GET_OP_CLASSES
2279 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2280 
2281 AffineMap mlir::linalg::extractOrIdentityMap(std::optional<AffineMap> maybeMap,
2282  unsigned rank,
2283  MLIRContext *context) {
2284  if (maybeMap)
2285  return *maybeMap;
2286  if (rank == 0)
2287  return AffineMap::get(context);
2288  return AffineMap::getMultiDimIdentityMap(rank, context);
2289 }
2290 
2292 mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
2293  MLIRContext *context) {
2295  res.reserve(num);
2296  for (unsigned i = 0; i < num; ++i)
2297  res.push_back(getAffineDimExpr(startIdx++, context));
2298  return res;
2299 }
2300 
2303  auto rangeA = llvm::make_range(a.begin(), a.end());
2304  auto rangeB = llvm::make_range(b.begin(), b.end());
2305  auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2306  return llvm::to_vector<4>(concatRanges);
2307 }
2308 
/// Appends a mangled encoding of `t` to `ss`:
///   memref -> "view" + per-dim "<size>x" ("sx" for dynamic) + element type
///             [+ "as<int>" for an integer memory space],
///   vector -> "vector" + shape joined with "x" + element type,
///   signless int/index/float -> printed as-is.
/// Returns failure for any other type (or non-integer memory space).
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
  if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
    ss << "view";
    for (auto size : memref.getShape())
      if (size < 0)
        ss << "sx";
      else
        ss << size << "x";
    if (failed(appendMangledType(ss, memref.getElementType())))
      return failure();
    if (auto as = memref.getMemorySpace()) {
      if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
        ss << "as" << attr.getInt();
      else
        return failure();
    }
    return success();
  }
  if (auto vec = llvm::dyn_cast<VectorType>(t)) {
    ss << "vector";
    llvm::interleave(
        vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
    if (failed(appendMangledType(ss, vec.getElementType())))
      return failure();
    return success();
  }
  if (t.isSignlessIntOrIndexOrFloat()) {
    ss << t;
    return success();
  }
  return failure();
}
2341 
2343  assert(isa<LinalgOp>(op));
2344  std::string name(op->getName().getStringRef().str());
2345  std::string fun = "";
2346  for (NamedAttribute kv : op->getAttrs()) {
2347  if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2348  fun = stringifyEnum(ufa.getValue()).str() + "_";
2349  } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2350  fun = stringifyEnum(bfa.getValue()).str() + "_";
2351  }
2352  }
2353  name.reserve(128);
2354  std::replace(name.begin(), name.end(), '.', '_');
2355  llvm::raw_string_ostream ss(name);
2356  ss << "_" << fun;
2357  for (Type t : op->getOperandTypes()) {
2358  if (failed(appendMangledType(ss, t)))
2359  return std::string();
2360  ss << "_";
2361  }
2362  name.pop_back();
2363  return name;
2364 }
2365 
2366 //===----------------------------------------------------------------------===//
2367 // Canonicalizers and Folders.
2368 //===----------------------------------------------------------------------===//
2369 
2370 namespace {
2371 struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
2373 
2374  LogicalResult matchAndRewrite(LinalgOp op,
2375  PatternRewriter &rewriter) const override {
2376  for (OpOperand &opOperand : op->getOpOperands()) {
2377  // Linalg "inputs" may be either tensor or memref type.
2378  // tensor<0xelt_type> is a convention that may not always mean
2379  // "0 iterations". Only erase in cases we see memref<...x0x...>.
2380  auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2381  if (!mt)
2382  continue;
2383  if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2384  rewriter.eraseOp(op);
2385  return success();
2386  }
2387  }
2388  return failure();
2389  }
2390 };
2391 
2392 /// Fold LinalgOps with `tensor.cast` consumer if the `tensor.cast` has
2393 /// result that is more static than the linalg op.
2394 struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
2396 
2397  LogicalResult matchAndRewrite(tensor::CastOp castOp,
2398  PatternRewriter &rewriter) const override {
2399  if (!tensor::canFoldIntoProducerOp(castOp))
2400  return failure();
2401 
2402  auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2403  if (!linalgOp)
2404  return failure();
2405 
2406  // Cast can be in conditionally reachable region, if which case folding will
2407  // generate invalid code. Only conservatively fold ops in same block for
2408  // now.
2409  if (castOp->getBlock() != linalgOp->getBlock())
2410  return failure();
2411 
2412  OpBuilder::InsertionGuard guard(rewriter);
2413  rewriter.setInsertionPoint(linalgOp);
2414 
2415  Location loc = linalgOp.getLoc();
2416  OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2417  unsigned resultNumber = resultValue.getResultNumber();
2418  auto resultType =
2419  llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2420  // Replace the `outs` for the result with a `tensor.cast`. This cast is now
2421  // going from a more dynamic shape to a less dynamic shape. If the producer
2422  // for this cast, i.e. producer of the out operand, is also an operation
2423  // that folds with tensor.cast consumer (like this pattern), the cast will
2424  // continue to propagate as far up the stack as it can go.
2425  OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2426  Value newOperand =
2427  rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
2428  SmallVector<Value> newOperands = linalgOp.getDpsInputs();
2429  SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
2430  linalgOp.getDpsInits().end());
2431  outputOperands[resultNumber] = newOperand;
2432  newOperands.append(outputOperands.begin(), outputOperands.end());
2433 
2434  SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
2435  linalgOp->result_type_end());
2436  resultTypes[resultNumber] = resultType;
2437  Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2438 
2439  // Create a tensor.cast operation back to the original type.
2440  Value castBack = rewriter.create<tensor::CastOp>(
2441  loc, resultValue.getType(), newOp->getResult(resultNumber));
2442 
2443  SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
2444  results[resultNumber] = castBack;
2445  rewriter.replaceOp(linalgOp, results);
2446  rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
2447  return success();
2448  }
2449 };
2450 
/// For each of the operand in `operands` this function maps the static sizes of
/// dimensions to their affine dim expressions.
static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
                        llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
  for (OpOperand &opOperand : operands) {
    // Scalars carry no shape information.
    if (linalgOp.isScalar(&opOperand))
      continue;
    Value src = opOperand.get();
    auto sourceType = llvm::cast<RankedTensorType>(src.getType());
    auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);

    // Get the `sourceShape` of the `sourceType`. If the operand is a result of
    // `tensor.cast` operation and source of the cast operation has a static
    // shape, then assign it to the `sourceShape`.
    auto *parentOp = src.getDefiningOp();
    ArrayRef<int64_t> sourceShape = sourceType.getShape();
    if (parentOp) {
      if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
        Value castSource = castOp.getSource();
        auto castSourceType =
            llvm::dyn_cast<RankedTensorType>(castSource.getType());
        if (castSourceType && castSourceType.hasStaticShape())
          sourceShape = castSourceType.getShape();
      }
    }

    // If the source shape's dimension has a static shape, map the affine dim
    // expression to the known static size.
    for (unsigned i = 0; i < sourceShape.size(); i++) {
      if (sourceType.isDynamicDim(i))
        continue;
      if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
        affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
    }
  }
}
2487 
/// Creates new operand w.r.t 'opOperand' of `linalgOp` with static sizes
/// mapped in `affineExprToSize`. New operands are created in `newOperands` and
/// their result types is stored in `resultTypes`. If `opOperand` requires no
/// change then `changeNeeded` is false and same operand is added in the
/// `newOperands` list.
static void createNewOperandWithStaticSizes(
    Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
    llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
    SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
    bool &changeNeeded) {
  Value src = opOperand->get();
  // Tentatively keep the original operand; replaced below if a cast is needed.
  newOperands.push_back(src);
  if (linalgOp.isScalar(opOperand))
    return;
  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
  Type resultType = sourceType;
  // Already fully static init: nothing to refine, just record the type.
  if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
    resultTypes.push_back(resultType);
    return;
  }
  ArrayRef<int64_t> sourceShape = sourceType.getShape();
  AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
  SmallVector<int64_t> newShape;
  // If operand is updated with new shape, `newOperandNeeded` will be
  // true.
  bool newOperandNeeded = false;
  for (unsigned i = 0; i < sourceShape.size(); i++) {
    int64_t dimShape = sourceShape[i];
    AffineExpr dimExpr = sourceMap.getResult(i);
    if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
      newShape.push_back(dimShape);
      continue;
    }
    // Dimension has a dynamic shape and corresponding affine dim
    // expression is present in the map. So assign the size for the
    // given affine dim expression to the dimension.
    newShape.push_back(affineExprToSize[dimExpr]);
    newOperandNeeded = true;
  }
  resultType = RankedTensorType::get(newShape, sourceType.getElementType());
  if (newOperandNeeded) {
    changeNeeded = true;
    // Get the new operand value given its size and element type by
    // casting it.
    Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
    unsigned index = opOperand->getOperandNumber();
    newOperands[index] = newOperand;
  }
  if (linalgOp.isDpsInit(opOperand))
    resultTypes.push_back(resultType);
}
2539 
2540 /// Static shapes for the operands can be inferred if any one of the operands
2541 /// have a static shape. This can be done by referring to the affine dim
2542 /// expressions for the operand.
2543 struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
2545 
2546  LogicalResult matchAndRewrite(LinalgOp linalgOp,
2547  PatternRewriter &rewriter) const override {
  // Refinement through tensor.cast only applies when every operand/result
  // is a tensor.
2548  if (!linalgOp.hasPureTensorSemantics())
2549  return failure();
2550 
2551  // Maps must be projected permutations.
2552  if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
2553  return !map.isProjectedPermutation();
2554  }))
2555  return failure();
2556 
2557  // Maps affine dim expressions to the static size of that dimension.
2558  llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
2559  Location loc = linalgOp.getLoc();
2560 
2561  // For each of the affine dim expression, check if the size is known. If
2562  // known add that in the map.
2563  populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2564 
2565  SmallVector<Value> newOperands;
2566  SmallVector<Type> resultTypes;
2567 
2568  // `changeNeeded` is `false` if the operands of `linalgOp` require no
2569  // change in their types.
2570  bool changeNeeded = false;
2571  newOperands.reserve(linalgOp->getNumOperands());
2572  resultTypes.reserve(linalgOp.getNumDpsInits());
2573 
2574  // Iterate over all the operands and update the static sizes.
2575  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2576  createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2577  affineExprToSize, linalgOp, newOperands,
2578  resultTypes, changeNeeded);
2579  }
2580 
2581  // If the generic op has all the required static information, no
2582  // canonicalization needed.
2583  if (!changeNeeded)
2584  return failure();
2585 
2586  // Clone op.
2587  Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2588  SmallVector<Value> replacements;
2589  replacements.reserve(newOp->getNumResults());
  // Results whose type was refined must be cast back to the original (more
  // dynamic) type so that existing uses remain type-correct.
2590  for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
2591  Value newResult = std::get<1>(it);
2592  Value oldResult = std::get<0>(it);
2593  Type newType = newResult.getType();
2594  Type oldType = oldResult.getType();
2595  replacements.push_back(
2596  (newType != oldType)
2597  ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
2598  : newResult);
2599  }
2600  rewriter.replaceOp(linalgOp, replacements);
2601  return success();
2602  }
2603 };
2604 
2605 } // namespace
2606 
2607 // All named ops canonicalizers and folders are auto-generated in the
2608 // .cpp.inc.
2609 
2610 //===----------------------------------------------------------------------===//
2611 // SoftmaxOp
2612 //===----------------------------------------------------------------------===//
2613 
2614 LogicalResult SoftmaxOp::verify() {
2615  ShapedType inputType = getInputOperandType();
2616  ShapedType outputType = getOutputOperandType();
2617 
2618  ArrayRef<int64_t> inputShape = inputType.getShape();
2619  ArrayRef<int64_t> outputShape = outputType.getShape();
2620  if (failed(verifyCompatibleShape(inputShape, outputShape)))
2621  return emitOpError("incompatible output shape");
2622 
2623  int64_t inputRank = getInputOperandRank();
2624  int64_t dimension = getDimension();
2625  if ((dimension < 0) || (dimension >= inputRank))
2626  return emitOpError("incorrect dimension specified");
2627 
2628  return success();
2629 }
2630 
2631 SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
2632  int64_t operandRank = getInputOperandRank();
2633  SmallVector<Range> loopBounds(operandRank);
2634  Location loc = getLoc();
2635  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
2636  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
2637  Value source = getInput();
2638  for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
2639  loopBounds[dim].offset = zero;
2640  loopBounds[dim].size = getDimValue(builder, loc, source, dim);
2641  loopBounds[dim].stride = one;
2642  }
2643  return loopBounds;
2644 }
2645 
2646 SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
2647  SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
2648  utils::IteratorType::parallel);
2649  iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2650  return iteratorTypes;
2651 }
2652 
2653 FailureOr<TilingResult>
2654 SoftmaxOp::getTiledImplementation(OpBuilder &builder,
2655  ArrayRef<OpFoldResult> offsets,
2656  ArrayRef<OpFoldResult> sizes) {
  // Tile by slicing both the input and the output with the same
  // offsets/sizes (unit strides) and cloning the op onto the slices.
2657  int64_t rank = getInputOperandRank();
2658  auto oneAttr = builder.getI64IntegerAttr(1);
2659  SmallVector<OpFoldResult> strides(rank, oneAttr);
2660  SmallVector<Value> tiledOperands;
  // getSlice produces tensor.extract_slice or memref.subview depending on
  // the operand type.
2661  Operation *inputSlice =
2662  getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
2663  if (!inputSlice) {
2664  return emitOpError("failed to compute input slice");
2665  }
2666  tiledOperands.emplace_back(inputSlice->getResult(0));
2667  Operation *outputSlice =
2668  getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
2669  if (!outputSlice) {
2670  return emitOpError("failed to compute output slice");
2671  }
2672  tiledOperands.emplace_back(outputSlice->getResult(0));
2673 
  // With buffer semantics the tiled op has no results; with tensor
  // semantics it returns the tiled output type.
2674  SmallVector<Type, 4> resultTypes;
2675  if (hasPureTensorSemantics())
2676  resultTypes.push_back(tiledOperands[1].getType());
2677  Operation *tiledOp =
2678  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2679 
2680  return TilingResult{
2681  {tiledOp},
2682  SmallVector<Value>(tiledOp->getResults()),
2683  llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
2684 }
2685 
2686 LogicalResult SoftmaxOp::getResultTilePosition(
2687  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2688  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2689  SmallVector<OpFoldResult> &resultSizes) {
2690  if (resultNumber == 0) {
2691  resultOffsets.assign(offsets.begin(), offsets.end());
2692  resultSizes.assign(sizes.begin(), sizes.end());
2693  return success();
2694  }
2695  return failure();
2696 }
2697 
2698 // cast(dynamic) -> static.
2699 LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  // Fold producing memref.cast ops into this op when the cast only makes
  // the operand type more static (delegated to memref::foldMemRefCast).
2700  return memref::foldMemRefCast(*this);
2701 }
2702 
2703 LogicalResult
2705  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  // Reify the result shape dim-by-dim: static dims become index attributes
  // taken from the operand type, dynamic dims become dim ops on the input.
2707  Location loc = getOperation()->getLoc();
2708  IRRewriter rewriter(b);
2709  auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2710  auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2711  for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2712  if (!outputShapedType.isDynamicDim(dim)) {
2713  // Static dim: Return IntegerAttr.
2714  shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
2715  } else {
2716  // Dynamic dim: Return Value.
2717  OpFoldResult ofr = createOrFoldDimOp(b, loc, getInput(), dim);
2718  shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
2719  }
2720  }
2721  reifiedReturnShapes.emplace_back(std::move(shapes));
2722  return success();
2723 }
2724 
2725 void SoftmaxOp::getEffects(
2727  &effects) {
  // Memory effects only exist with buffer semantics: memref inputs are
  // read; memref inits are both read and written. Tensor operands carry no
  // memory effects.
2728  for (auto [index, operand] : llvm::enumerate(getDpsInputs())) {
2729  if (!llvm::isa<MemRefType>(operand.getType()))
2730  continue;
2731  effects.emplace_back(MemoryEffects::Read::get(),
2732  &getOperation()->getOpOperand(index), /*stage=*/0,
2733  /*effectOnFullRegion=*/true,
2735  }
2736 
2737  for (OpOperand &operand : getDpsInitsMutable()) {
2738  if (!llvm::isa<MemRefType>(operand.get().getType()))
2739  continue;
  // Inits get both a Read and a Write effect.
2740  effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
2741  /*effectOnFullRegion=*/true,
2743  effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
2744  /*effectOnFullRegion=*/true,
2746  }
2747 }
2748 
2749 // Helper functions for softmax decomposition.
2750 // @{
2751 
2752 // Helper function to produce the iterator types (reduction or parallel) and
2753 // affine maps for the iterators used in the decomposition of softmax.
2754 // This method creates:
2755 // If allParallel == true:
2756 // - iterator type: {parallel, ..., parallel}
2757 // - affine maps:
2758 // -- identity with inputRank dimensions.
2759 // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2760 // where N == inputRank.
2761 //
2762 // If allParallel == false:
2763 // - iterator type at dim(i) == parallel for i != \p dim and
2764 // dim(dim) == reduction.
2765 // - affine map:
2766 // -- identity with inputRank dimensions.
2767 // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2768 // where N == inputRank.
2769 static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
2771  int64_t dim, bool allParallel = false) {
  // All iterators are parallel, except `dim` which becomes a reduction
  // unless `allParallel` is set.
2772  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
2773  utils::IteratorType::parallel);
2774  if (!allParallel)
2775  iteratorTypes[dim] = utils::IteratorType::reduction;
2776  MLIRContext *ctxt = builder.getContext();
2777  auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
  // Second map drops `dim`: (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1,
  // ..., dN).
2778  SmallVector<AffineExpr, 2> affineExprs;
2779  for (int i = 0; i < inputRank; i++) {
2780  if (i != dim)
2781  affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
2782  }
2783  auto reductionMap =
2784  AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
2785  SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
2786  return std::make_tuple(iteratorTypes, indexingMaps);
2787 }
2788 
2789 // Helper function to produce a linalg.generic that computes a reduction on
2790 // dimension \p dim with the operation type \p T.
2791 template <typename T>
2792 static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
2793  int64_t dim) {
2794  auto inputType = cast<ShapedType>(input.getType());
2795  ArrayRef<int64_t> inputShape = inputType.getShape();
2796  int64_t inputRank = inputShape.size();
2797  auto [iteratorTypes, indexingMaps] =
2798  computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
2799  assert(indexingMaps.size() == 2 &&
2800  "We should have two maps: 1 for the input, 1 for the output");
2801  assert(indexingMaps[0].isIdentity() && "input map should be identity");
2802 
2803  auto genericOp = builder.create<linalg::GenericOp>(
2804  loc, output.getType(), input, output, indexingMaps, iteratorTypes,
2805  [&](OpBuilder &b, Location loc, ValueRange args) {
2806  Value result = b.create<T>(loc, args[0], args[1]);
2807  b.create<linalg::YieldOp>(loc, result);
2808  });
2809  return genericOp.getResult(0);
2810 }
2811 
2812 /// Produce a linalg generic that computes the second step of the softmax
2813 /// decomposition: res = exp(input - max), where \p max is the max of \p input
2814 /// on dimension \p dim.
2815 static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
2816  Value max, Value output, int64_t dim) {
2817  auto inputType = cast<ShapedType>(input.getType());
2818  ArrayRef<int64_t> inputShape = inputType.getShape();
2819  int64_t inputRank = inputShape.size();
2820  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
2821  builder, inputRank, dim, /*allParallel=*/true);
2822  assert(indexingMaps.size() == 2 && "We should have one map for each input");
2823  assert(indexingMaps[0].isIdentity() && "input map should be identity");
2824  // Add the affine map for the output argument.
2825  indexingMaps.push_back(indexingMaps[0]);
2826  auto genericOp = builder.create<linalg::GenericOp>(
2827  loc, input.getType(), ValueRange{input, max}, output, indexingMaps,
2828  iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
2829  Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
2830  Value result = b.create<math::ExpOp>(loc, diff);
2831  b.create<linalg::YieldOp>(loc, result);
2832  });
2833  return genericOp.getResult(0);
2834 }
2835 
2836 /// Produce a linalg generic that computes the final step of the softmax
2837 /// decomposition.
2838 /// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
2839 /// yield n / d
2840 /// }
2841 static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
2842  Value denominator, Value output, int64_t dim) {
2843  auto inputType = cast<ShapedType>(numerator.getType());
2844  ArrayRef<int64_t> inputShape = inputType.getShape();
2845  int64_t inputRank = inputShape.size();
2846  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
2847  builder, inputRank, dim, /*allParallel=*/true);
2848  assert(indexingMaps.size() == 2 &&
2849  "We should have one map for each input (2)");
2850  assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
2851  // Add the affine map for the output tensor.
2852  indexingMaps.push_back(indexingMaps[0]);
2853  auto genericOp = builder.create<linalg::GenericOp>(
2854  loc, numerator.getType(), ValueRange{numerator, denominator}, output,
2855  indexingMaps, iteratorTypes,
2856  [&](OpBuilder &b, Location loc, ValueRange args) {
2857  Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
2858  b.create<linalg::YieldOp>(loc, result);
2859  });
2860  return genericOp.getResult(0);
2861 }
2862 // @} End helper functions for softmax decomposition.
2863 
2864 /// Given an N-dimensional tensor x, this method converts
2865 /// softmax(x) to the following sequence of operations:
2866 ///
2867 /// 1. Compute the max of x along dimension d. This results
2868 /// in a N-1 dimensional tensor m.
2869 /// m = max(x, dim = d)
2870 ///
2871 /// 2. Subtract a broadcasted m from x and exponentiate. This results in
2872 /// a N dimensional tensor z.
2873 /// z = exp(x - m)
2874 ///
2875 /// 3. Compute the sum of z along dimension d. This results in
2876 /// a N-1 dimensional tensor l.
2877 /// l = sum(z, dim = d)
2878 ///
2879 /// 4. Divide z and l. This gives the N-dimensional softmax.
2880 /// softmax = z / l
2881 ///
2882 FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
2883  OpBuilder::InsertionGuard guard(b);
2884  b.setInsertionPoint(*this);
2885  Location loc = getLoc();
2886  Value input = getInput();
2887  ShapedType inputType = getInputOperandType();
2888  Type elementType = inputType.getElementType();
2889  int64_t reductionDim = getDimension();
2890  SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
2891  Value output = getOutput();
  // The reduction results (max and sum) have rank N-1: drop the reduced
  // dimension from the size list.
2892  dims.erase(dims.begin() + reductionDim);
2893  // Step 1: Compute max along dim.
2894  Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
  // Initialize the accumulator with the identity of maximumf, restricted to
  // finite values.
2895  Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maximumf,
2896  elementType, b, loc,
2897  /*useOnlyFiniteValue=*/true);
2898  Value neutralForMaxFInit =
2899  b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
2900  .result();
2901  Value max =
2902  reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
2903 
2904  // Step 2: Subtract max from input and exponentiate.
2905  Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);
2906 
2907  // Step 3: Compute sum along dim.
  // The same rank-(N-1) tensor.empty is re-filled with the additive
  // identity and reused as the sum accumulator.
2908  Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
2909  b, loc, /*useOnlyFiniteValue=*/true);
2910  Value zeroInit =
2911  b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result();
2912  Value denominator =
2913  reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
2914 
2915  // Step 4: Compute softmax.
2916  Value result =
2917  buildDivOp(b, loc, numerator, denominator, output, reductionDim);
2918  return SmallVector<Value>{result};
2919 }
2920 
2921 //===----------------------------------------------------------------------===//
2922 // WinogradFilterTransformOp
2923 //===----------------------------------------------------------------------===//
2924 
2925 LogicalResult WinogradFilterTransformOp::verify() {
2926  auto filterType = cast<ShapedType>(getFilter().getType());
2927  ArrayRef<int64_t> filterShape = filterType.getShape();
2928  int64_t filterH = filterShape[getFilterHDim()];
2929  int64_t filterW = filterShape[getFilterWDim()];
2930  int64_t r = getR();
2931  int64_t m = getM();
2932 
2933  if (filterH != r && filterH != 1)
2934  return emitOpError("expect filter height either equals to r or 1");
2935  if (filterW != r && filterW != 1)
2936  return emitOpError("expect filter width either equals to r or 1");
2937  if (filterH == 1 && filterW == 1)
2938  return emitOpError("expect either filter height or width equals to r");
2939 
2940  SmallVector<int64_t> expectedOutputShape;
2941  expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
2942  expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
2943  expectedOutputShape.push_back(filterShape[getFilterCDim()]);
2944  expectedOutputShape.push_back(filterShape[getFilterFDim()]);
2945 
2946  auto outputType = cast<ShapedType>(getOutput().getType());
2947  ArrayRef<int64_t> outputShape = outputType.getShape();
2948  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
2949  return emitOpError("the output shape is not expected");
2950  }
2951  return success();
2952 }
2953 
2955 WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
  // Iteration domain = the full filter shape (F, KH, KW, C), each dim
  // iterated over [0, dimSize) with unit stride.
2956  Location loc = getLoc();
2957  IntegerAttr zeroAttr = builder.getIndexAttr(0);
2958  IntegerAttr oneAttr = builder.getIndexAttr(1);
2959  Value filter = getFilter();
2960  int64_t filterRank = getFilterOperandRank();
2961  SmallVector<Range> loopBounds(filterRank);
2962  for (unsigned dim = 0; dim < filterRank; ++dim) {
2963  loopBounds[dim].offset = zeroAttr;
2964  loopBounds[dim].size = getDimValue(builder, loc, filter, dim);
2965  loopBounds[dim].stride = oneAttr;
2966  }
2967  return loopBounds;
2968 }
2969 
2971 WinogradFilterTransformOp::getLoopIteratorTypes() {
  // Every loop of the filter transform is parallel.
2972  int64_t filterRank = getFilterOperandRank();
2973  SmallVector<utils::IteratorType> iteratorTypes(filterRank,
2974  utils::IteratorType::parallel);
2975  return iteratorTypes;
2976 }
2977 
2978 LogicalResult WinogradFilterTransformOp::getResultTilePosition(
2979  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2980  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2981  SmallVector<OpFoldResult> &resultSizes) {
2982  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
2983  ShapedType filterType = getFilterOperandType();
2984  ArrayRef<int64_t> filterShape = filterType.getShape();
2985  int64_t filterH = filterShape[getFilterHDim()];
2986  int64_t filterW = filterShape[getFilterWDim()];
2987  int64_t m = getM();
2988  int64_t r = getR();
2989  int64_t alpha = m + r - 1;
2990  int64_t alphaH = filterH != 1 ? alpha : 1;
2991  int64_t alphaW = filterW != 1 ? alpha : 1;
2992  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
2993  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
2994 
2995  resultOffsets.append(
2996  {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
2997  resultSizes.append(
2998  {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
2999 
3000  return success();
3001 }
3002 
3003 /// Implement tiling for winograd_filter_transform
3004 /// The input of winograd_filter_transform is (F, KH, KW, C).
3005 /// The output of winograd_filter_transform is (alphaH, alphaW, C, F)
3006 /// Users can specify the tile sizes of F and C.
3007 /// `offsets` are the values for the offsets of F, KH, KW, C for one tile.
3008 /// `sizes` are the values for the sizes of F, KH, KW, C for one tile.
3009 FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
3010  OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3011  ArrayRef<OpFoldResult> sizes) {
3012  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3013  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3014  ShapedType filterType = getFilterOperandType();
3015  ArrayRef<int64_t> filterShape = filterType.getShape();
3016  int64_t filterH = filterShape[getFilterHDim()];
3017  int64_t filterW = filterShape[getFilterWDim()];
3018  IntegerAttr filterHAttr = builder.getI64IntegerAttr(filterH);
3019  IntegerAttr filterWAttr = builder.getI64IntegerAttr(filterW);
3020  SmallVector<Value> tiledOperands;
3021  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3022 
  // Slice the filter along F and C only; KH and KW are always taken whole.
3023  sliceOffsets.append(
3024  {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
3025  sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
3026  sizes[getFilterCDim()]});
3027  int64_t filterRank = getFilterOperandRank();
3028  SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
3029  Location loc = getLoc();
3030  auto filterSlice = builder.create<tensor::ExtractSliceOp>(
3031  loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
3032  tiledOperands.emplace_back(filterSlice);
3033 
  // Derive the matching output tile from the iteration-space tile.
3034  SmallVector<OpFoldResult> resultOffsets, resultSizes;
3035  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3036  resultSizes)))
3037  return failure();
3038 
3039  int64_t outputRank = getOutputOperandRank();
3040  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3041  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
3042  loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3043  tiledOperands.emplace_back(outputSlice);
3044 
  // Clone the op onto the slices; the result type is the tiled output type.
3045  SmallVector<Type> resultTypes;
3046  resultTypes.push_back(tiledOperands[1].getType());
3047  Operation *tiledOp =
3048  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3049 
3050  return TilingResult{
3051  {tiledOp},
3052  SmallVector<Value>(tiledOp->getResults()),
3053  llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
3054 }
3055 
3056 //===----------------------------------------------------------------------===//
3057 // WinogradInputTransformOp
3058 //===----------------------------------------------------------------------===//
3059 
3060 LogicalResult WinogradInputTransformOp::verify() {
3061  auto inputType = cast<ShapedType>(getInput().getType());
3062  ArrayRef<int64_t> inputShape = inputType.getShape();
3063  int64_t inputH = inputShape[getInputHDim()];
3064  int64_t inputW = inputShape[getInputWDim()];
3065  int m = getM();
3066  int r = getR();
3067  int64_t tileSize = m + r - 1;
3068  bool leftTransform = inputH != 1;
3069  bool rightTransform = inputW != 1;
3070 
3071  SmallVector<int64_t> expectedOutputShape(6, inputH);
3072  if (ShapedType::isDynamic(inputH)) {
3073  expectedOutputShape[getOutputAlphaHDim()] = tileSize;
3074  expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
3075  } else {
3076  expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
3077  expectedOutputShape[getOutputTileHDim()] =
3078  leftTransform ? (inputH - (r - 1)) / m : 1;
3079  }
3080  if (ShapedType::isDynamic(inputW)) {
3081  expectedOutputShape[getOutputAlphaWDim()] = tileSize;
3082  expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
3083  } else {
3084  expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
3085  expectedOutputShape[getOutputTileWDim()] =
3086  rightTransform ? (inputW - (r - 1)) / m : 1;
3087  }
3088  expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
3089  expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
3090 
3091  auto outputType = cast<ShapedType>(getOutput().getType());
3092  ArrayRef<int64_t> outputShape = outputType.getShape();
3093  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3094  return emitOpError("the output shape is not expected");
3095  }
3096  return success();
3097 }
3098 
3100 WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
  // Iteration domain = the full output shape (alphaH, alphaW, tileH, tileW,
  // N, C), each dim iterated over [0, dimSize) with unit stride.
3101  Location loc = getLoc();
3102  IntegerAttr zeroAttr = builder.getIndexAttr(0);
3103  IntegerAttr oneAttr = builder.getIndexAttr(1);
3104  Value output = getOutput();
3105  int64_t outputRank = getOutputOperandRank();
3106  SmallVector<Range> loopBounds(outputRank);
3107  for (unsigned dim = 0; dim < outputRank; ++dim) {
3108  loopBounds[dim].offset = zeroAttr;
3109  // alphaH, alphaW, tileH, tileW, N, C
3110  loopBounds[dim].size = getDimValue(builder, loc, output, dim);
3111  loopBounds[dim].stride = oneAttr;
3112  }
3113  return loopBounds;
3114 }
3115 
3117 WinogradInputTransformOp::getLoopIteratorTypes() {
  // Every loop of the input transform is parallel.
3118  int64_t outputRank = getOutputOperandRank();
3119  SmallVector<utils::IteratorType> iteratorTypes(outputRank,
3120  utils::IteratorType::parallel);
3121  return iteratorTypes;
3122 }
3123 
3124 LogicalResult WinogradInputTransformOp::getResultTilePosition(
3125  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3126  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3127  SmallVector<OpFoldResult> &resultSizes) {
3128  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3129  ShapedType inputType = getInputOperandType();
3130  ArrayRef<int64_t> inputShape = inputType.getShape();
3131  int64_t inputH = inputShape[getInputHDim()];
3132  int64_t inputW = inputShape[getInputWDim()];
3133  int64_t m = getM();
3134  int64_t r = getR();
3135  int64_t alpha = m + r - 1;
3136  int64_t alphaH = inputH != 1 ? alpha : 1;
3137  int64_t alphaW = inputW != 1 ? alpha : 1;
3138  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3139  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3140 
3141  resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
3142  offsets[getOutputTileWDim()], offsets[getOutputNDim()],
3143  offsets[getOutputCDim()]});
3144  resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
3145  sizes[getOutputTileWDim()], sizes[getOutputNDim()],
3146  sizes[getOutputCDim()]});
3147 
3148  return success();
3149 }
3150 
3151 /// Implement tiling for winograd_input_transform
3152 /// The input of winograd_input_transform is (N, H, W, C).
3153 /// The output of winograd_input_transform is (alphaH, alphaW, tileH, tileW, N,
3154 /// C) Users can specify the tile sizes of tileH, tileW, N, and C. `offsets` are
3155 /// the values for the offsets of tileH, tileW, N, C for one tile. `sizes` are
3156 /// the values for the sizes of tileH, tileW, N, C for one tile.
3157 FailureOr<TilingResult>
3158 WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
3159  ArrayRef<OpFoldResult> offsets,
3160  ArrayRef<OpFoldResult> sizes) {
3161  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3162  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3163  ShapedType inputType = getInputOperandType();
3164  ArrayRef<int64_t> inputShape = inputType.getShape();
3165  int64_t inputH = inputShape[getInputHDim()];
3166  int64_t inputW = inputShape[getInputWDim()];
3167  int64_t m = getM();
3168  int64_t r = getR();
3169 
3170  Location loc = getLoc();
3171  MLIRContext *context = builder.getContext();
  // Map tile indices/counts back to the input's H/W coordinates:
  // offset = tileOffset * m, and size = tileCount * m + (r - 1) because
  // adjacent tiles overlap by r - 1 elements.
3172  auto offsetAffineMap =
3173  AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3174  Value mappedOffsetH = affine::makeComposedAffineApply(
3175  builder, loc, offsetAffineMap, offsets[getOutputTileHDim()]);
3176  Value mappedOffsetW = affine::makeComposedAffineApply(
3177  builder, loc, offsetAffineMap, offsets[getOutputTileWDim()]);
3178  auto sizeAffineMap = AffineMap::get(
3179  1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);
3180  Value mappedSizeH = affine::makeComposedAffineApply(
3181  builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
3182  Value mappedSizeW = affine::makeComposedAffineApply(
3183  builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
3184 
3185  SmallVector<Value> tiledOperands;
3186  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3187 
  // A unit H/W dimension is not transformed: take it whole at offset 0.
3188  OpFoldResult offsetH =
3189  inputH != 1 ? OpFoldResult(mappedOffsetH) : OpFoldResult(zeroAttr);
3190  OpFoldResult offsetW =
3191  inputW != 1 ? OpFoldResult(mappedOffsetW) : OpFoldResult(zeroAttr);
3192  sliceOffsets.append(
3193  {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
3194  OpFoldResult sizeH =
3195  inputH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3196  OpFoldResult sizeW =
3197  inputW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3198  sliceSizes.append(
3199  {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
3200  int64_t inputRank = getInputOperandRank();
3201  SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
3202  auto inputSlice = builder.create<tensor::ExtractSliceOp>(
3203  loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
3204  tiledOperands.emplace_back(inputSlice);
3205 
  // Derive the matching output tile from the iteration-space tile.
3206  SmallVector<OpFoldResult> resultOffsets, resultSizes;
3207  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3208  resultSizes)))
3209  return failure();
3210 
3211  int64_t outputRank = getOutputOperandRank();
3212  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3213  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
3214  loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3215  tiledOperands.emplace_back(outputSlice);
3216 
  // Clone the op onto the slices; the result type is the tiled output type.
3217  SmallVector<Type> resultTypes;
3218  resultTypes.push_back(tiledOperands[1].getType());
3219  Operation *tiledOp =
3220  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3221 
3222  return TilingResult{
3223  {tiledOp},
3224  SmallVector<Value>(tiledOp->getResults()),
3225  llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
3226 }
3227 
3228 //===----------------------------------------------------------------------===//
3229 // WinogradOutputTransformOp
3230 //===----------------------------------------------------------------------===//
3231 
3232 LogicalResult WinogradOutputTransformOp::verify() {
3233  auto valueType = cast<ShapedType>(getValue().getType());
3234  ArrayRef<int64_t> valueShape = valueType.getShape();
3235  int64_t valueH = valueShape[getValueAlphaHDim()];
3236  int64_t valueW = valueShape[getValueAlphaWDim()];
3237  int64_t valueTileH = valueShape[getValueTileHDim()];
3238  int64_t valueTileW = valueShape[getValueTileWDim()];
3239  int m = getM();
3240  int r = getR();
3241  bool leftTransform = valueH != 1;
3242  bool rightTransform = valueW != 1;
3243 
3244  int64_t outputRank = getOutputOperandRank();
3245  SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
3246  if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
3247  expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
3248  } else {
3249  if (valueH != (leftTransform ? m + r - 1 : 1))
3250  return emitOpError("expect input height equals to input tile size");
3251  expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
3252  }
3253  if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
3254  expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
3255  } else {
3256  if (valueW != (rightTransform ? m + r - 1 : 1))
3257  return emitOpError("expect input width equals to input tile size");
3258  expectedOutputShape[getOutputWDim()] =
3259  (rightTransform ? m : 1) * valueTileW;
3260  }
3261  expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
3262  expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];
3263 
3264  auto outputType = cast<ShapedType>(getOutput().getType());
3265  ArrayRef<int64_t> outputShape = outputType.getShape();
3266  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3267  return emitOpError("the output shape is not expected");
3268  }
3269  return success();
3270 }
3271 
3273 WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
  // Iteration domain = the full value-operand shape (alphaH, alphaW, tileH,
  // tileW, N, F), each dim iterated over [0, dimSize) with unit stride.
3274  Location loc = getLoc();
3275  IntegerAttr zeroAttr = builder.getIndexAttr(0);
3276  IntegerAttr oneAttr = builder.getIndexAttr(1);
3277  Value value = getValue();
3278  int64_t valueRank = getValueOperandRank();
3279  SmallVector<Range> loopBounds(valueRank);
3280  for (unsigned dim = 0; dim < valueRank; ++dim) {
3281  loopBounds[dim].offset = zeroAttr;
3282  // alphaH, alphaW, tileH, tileW, N, F
3283  loopBounds[dim].size = getDimValue(builder, loc, value, dim);
3284  loopBounds[dim].stride = oneAttr;
3285  }
3286  return loopBounds;
3287 }
3288 
3290 WinogradOutputTransformOp::getLoopIteratorTypes() {
3291  int64_t valueRank = getValueOperandRank();
3292  SmallVector<utils::IteratorType> iteratorTypes(valueRank,
3293  utils::IteratorType::parallel);
3294  return iteratorTypes;
3295 }
3296 
3297 LogicalResult WinogradOutputTransformOp::getResultTilePosition(
3298  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3299  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3300  SmallVector<OpFoldResult> &resultSizes) {
3301  int64_t m = getM();
3302 
3303  Location loc = getLoc();
3304  MLIRContext *context = builder.getContext();
3305  auto affineMap =
3306  AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3307 
3308  Value mappedOffsetH = affine::makeComposedAffineApply(
3309  builder, loc, affineMap, offsets[getValueTileHDim()]);
3310  Value mappedOffsetW = affine::makeComposedAffineApply(
3311  builder, loc, affineMap, offsets[getValueTileWDim()]);
3312  Value mappedSizeH = affine::makeComposedAffineApply(
3313  builder, loc, affineMap, sizes[getValueTileHDim()]);
3314  Value mappedSizeW = affine::makeComposedAffineApply(
3315  builder, loc, affineMap, sizes[getValueTileWDim()]);
3316 
3317  ShapedType valueType = getValueOperandType();
3318  ArrayRef<int64_t> valueShape = valueType.getShape();
3319  int64_t valueH = valueShape[0];
3320  int64_t valueW = valueShape[1];
3321  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3322  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3323  OpFoldResult offsetH =
3324  valueH != 1 ? OpFoldResult(mappedOffsetH) : OpFoldResult(zeroAttr);
3325  OpFoldResult offsetW =
3326  valueW != 1 ? OpFoldResult(mappedOffsetW) : OpFoldResult(zeroAttr);
3327  OpFoldResult sizeH =
3328  valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3329  OpFoldResult sizeW =
3330  valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3331 
3332  resultOffsets.append(
3333  {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
3334  resultSizes.append(
3335  {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
3336  return success();
3337 }
3338 
/// Implement tiling for winograd_output_transform.
/// The input of winograd_output_transform is (alphaH, alphaW, tileH, tileW, N,
/// F) and its output is (N, H, W, F). Users can specify the tile sizes of
/// tileH, tileW, N, and F. `offsets` are the values for the offsets of tileH,
/// tileW, N, F for one tile. `sizes` are the values for the sizes of tileH,
/// tileW, N, F for one tile.
FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes) {
  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  Location loc = getLoc();
  SmallVector<Value> tiledOperands;
  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;

  // The alpha (transform-domain) dimensions are never tiled: the input slice
  // always spans the full alphaH x alphaW extent.
  ShapedType valueType = getValueOperandType();
  ArrayRef<int64_t> valueShape = valueType.getShape();
  int64_t alphaH = valueShape[getValueAlphaHDim()];
  int64_t alphaW = valueShape[getValueAlphaWDim()];
  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);

  sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
                       offsets[getValueTileWDim()], offsets[getValueNDim()],
                       offsets[getValueFDim()]});
  sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
                     sizes[getValueTileWDim()], sizes[getValueNDim()],
                     sizes[getValueFDim()]});
  int64_t valueRank = getValueOperandRank();
  SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
  auto valueSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
  tiledOperands.emplace_back(valueSlice);

  // Compute the matching tile of the output in (N, H, W, F) space.
  // NOTE(review): the resultNumber argument `1` is ignored by
  // getResultTilePosition; the op has a single result, so 0 would be the
  // expected value — confirm.
  SmallVector<OpFoldResult> resultOffsets, resultSizes;
  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
                                   resultSizes)))
    return failure();

  int64_t outputRank = getOutputOperandRank();
  SmallVector<OpFoldResult> strides(outputRank, oneAttr);
  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getOutput(), resultOffsets, resultSizes, strides);
  tiledOperands.emplace_back(outputSlice);

  // Clone the op onto the sliced operands; the result type is the type of the
  // tiled output slice.
  SmallVector<Type> resultTypes;
  resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{
      {tiledOp},
      SmallVector<Value>(tiledOp->getResults()),
      llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
}
3394 
3395 //===----------------------------------------------------------------------===//
3396 // LinalgDialect
3397 //===----------------------------------------------------------------------===//
3398 
/// Registers the dialect-wide canonicalization patterns: erasure of dead
/// linalg ops, folding of tensor.cast consumers, and inference of static
/// operand shapes.
void LinalgDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
              InferStaticShapeOfOperands>(getContext());
}
3404 
3406  Attribute value, Type type,
3407  Location loc) {
3408  return arith::ConstantOp::materialize(builder, value, type, loc);
3409 }
3410 
3411 /// Returns true if the result AffineExpr of the \p explicitMap is same as \p
3412 /// defaultMap.
3413 static bool isValidResultDimExprs(AffineMap explictMap, AffineMap defaultMap) {
3414  auto explicitRange = explictMap.getResults();
3415  auto defaultRange = defaultMap.getResults();
3416  DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
3417  DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
3418  llvm::set_union(explicitSet, defaultSet);
3419  return explicitSet == defaultSet;
3420 }
3421 
3422 /// Returns true if the \p explictMap is broadcasted with respect to the
3423 /// \p defaultMap.
3424 static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap) {
3425  return explictMap.getNumResults() < defaultMap.getNumResults();
3426 }
3427 
3428 /// Verifies the broadcast and transpose semantic sepecified by the explicit
3429 /// indexing map for the MatmulOp \p op for each operand specified by \p
3430 /// opIndex.
3431 static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
3432  unsigned opIndex) {
3433  SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
3434  SmallVector<AffineMap, 3> defaultIndexingMaps =
3435  matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
3436 
3437  auto opIndexingMap = opIndexingMaps[opIndex];
3438  auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3439  // Check general validity of indexing map results.
3440  if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
3441  return matmulOp->emitOpError()
3442  << "Unexpected dim expression in map result.";
3443 
3444  // Check if the requested broadcast is valid.
3445  if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3446  if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
3447  return matmulOp->emitOpError()
3448  << "Invalid broadcast requested, should be (d2).";
3449  }
3450  return success();
3451  }
3452  return success();
3453 }
3454 
3455 namespace mlir {
3456 namespace linalg {
3457 
3458 //===----------------------------------------------------------------------===//
3459 // MatMulOp
3460 //===----------------------------------------------------------------------===//
3461 
3462 /// Returns a list of AffineMap with the typical matmul indexing charactristic.
3463 SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
3464  AffineExpr d0, d1, d2;
3465  SmallVector<AffineMap> indexingMaps;
3466  bindDims(context, d0, d1, d2);
3467  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
3468  indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
3469  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
3470  return indexingMaps;
3471 }
3472 
3473 SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
3474  return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
3475  utils::IteratorType::parallel,
3476  utils::IteratorType::reduction};
3477 }
3478 
// Payload region takes three block arguments: lhs, rhs, and the accumulator.
unsigned MatmulOp::getNumRegionArgs() { return 3; }
3480 
// Delegates to the shared library-call name mangler for structured ops.
std::string MatmulOp::getLibraryCallName() {
  return generateLibraryCallName(getOperation());
}
3484 
// Indexing maps are carried as an op attribute rather than fixed statically.
bool MatmulOp::hasDynamicIndexingMaps() { return true; }
3486 
3487 /// Check if the op has broadcast and/or transpose semantic. Returns true if
3488 /// the user defined indexing maps are not equal to default map.
3489 bool MatmulOp::hasUserDefinedMaps() {
3490  SmallVector<AffineMap, 3> defaultMaps =
3491  getDefaultIndexingMaps(this->getContext());
3492  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
3493  return defaultMaps != explicitMaps;
3494 }
3495 
3496 /// Implements the block region builder for the MatmulOp. This is called by
3497 /// 'fillStructuredOpRegion'.
3498 void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
3499  ArrayRef<NamedAttribute> attrs) {
3500  assert(3 > 0 && block.getNumArguments() == 3 &&
3501  "MatmulOp regionBuilder expects 3 (>=0) args");
3502  RegionBuilderHelper helper(b, block);
3503  SmallVector<Value> yields;
3504 
3505  TypeFn castVal = TypeFn::cast_signed;
3506  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
3507  return attr.getName() == "cast";
3508  });
3509  if (castIter != attrs.end()) {
3510  if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3511  castVal = attr.getValue();
3512  }
3513 
3514  Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
3515  block.getArgument(0));
3516  Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
3517  block.getArgument(1));
3518  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
3519  Value value4 =
3520  helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
3521  yields.push_back(value4);
3522  helper.yieldOutputs(yields);
3523 }
3524 
3525 /// Returns true if the given broadcast map \p bcastMap is valid for this op.
3526 bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
3527  assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
3528  AffineExpr exp = bcastMap.getResult(0);
3529  // Invalid map if the common dimension of matmul not found.
3530  return exp.isFunctionOfDim(bcastMap.getNumDims() - 1);
3531 }
3532 
/// Parses a MatmulOp of the form:
///   linalg.matmul [indexing_maps = [affine_map<...>, ...]] ins(...) outs(...)
/// The optional `indexing_maps` attribute must be a bracketed list of affine
/// map attributes; when absent, the default matmul maps are attached.
ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<Attribute, 3> indexingMapsAttr;
  Attribute mapAttr;
  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
    if (parser.parseEqual())
      return failure();

    if (parser.parseLSquare())
      return failure();

    // Parse a comma-separated list of affine map attributes.
    do {
      if (parser.parseAttribute(mapAttr))
        return failure();
      if (!isa<AffineMapAttr>(mapAttr)) {
        return parser.emitError(parser.getCurrentLocation(),
                                "expected affine map attribute");
      }
      indexingMapsAttr.push_back(mapAttr);

      if (parser.parseOptionalComma())
        break;
    } while (true);

    if (parser.parseRSquare())
      return failure();
  }
  // Initialize indexingMaps, if not supplied explicitly.
  if (indexingMapsAttr.empty()) {
    indexingMapsAttr = llvm::map_to_vector(
        MatmulOp::getDefaultIndexingMaps(parser.getContext()),
        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  }
  result.addAttribute("indexing_maps",
                      parser.getBuilder().getArrayAttr(indexingMapsAttr));

  // Delegate ins/outs, result types, and the region to the shared
  // named-structured-op parser.
  return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
                                MatmulOp::getRegionBuilder());
}
/// Prints the op in named-structured form; `indexing_maps` is elided from the
/// generic attribute dictionary and re-printed explicitly only when it differs
/// from the default matmul maps (mirrors the optional syntax in parse()).
void MatmulOp::print(OpAsmPrinter &p) {
  SmallVector<StringRef, 3> elidedAttrs = {
      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
                         elidedAttrs);

  // Only print non-default indexing maps.
  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
      MatmulOp::getDefaultIndexingMaps(getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
    p << " indexing_maps = [";
    llvm::interleaveComma(getIndexingMaps(), p,
                          [&](Attribute attr) { p.printAttribute(attr); });
    p << "]";
  }
}
3587 
3588 /// Verify the user defined indexing maps.
3589 LogicalResult MatmulOp::verify() {
3590  // Verification of pure matmul is handled by verifyStructuredOpInterface().
3591  if (!hasUserDefinedMaps())
3592  return success();
3593 
3594  for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
3595  if (failed(verifyExtendedMatmulSemantic(*this, opIndex)))
3596  return failure();
3597  }
3598  return success();
3599 }
3600 
/// Folding hook: only folds away memref.cast producers on the operands.
LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}
3604 void MatmulOp::getEffects(
3606  &effects) {
3607  if (hasPureTensorSemantics())
3608  return;
3609  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
3610 }
3611 
/// Speculatability is derived from the generic structured-op implementation.
Speculation::Speculatability MatmulOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
3615 
3616 } // namespace linalg
3617 } // namespace mlir
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static MLIRContext * getContext(OpFoldResult val)
static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantics specified by the explicit indexing map for the MatmulO...
Definition: LinalgOps.cpp:3431
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs)
Definition: LinalgOps.cpp:1839
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
Definition: LinalgOps.cpp:295
static std::tuple< SmallVector< utils::IteratorType >, SmallVector< AffineMap > > computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel=false)
Definition: LinalgOps.cpp:2770
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim)
Produce a linalg generic that computes the final step of the softmax decomposition.
Definition: LinalgOps.cpp:2841
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
Definition: LinalgOps.cpp:126
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
Definition: LinalgOps.cpp:2309
static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap)
Returns true if the explictMap is broadcasted with respect to the defaultMap.
Definition: LinalgOps.cpp:3424
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
Definition: LinalgOps.cpp:283
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
Definition: LinalgOps.cpp:1672
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
Definition: LinalgOps.cpp:1720
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim)
Produce a linalg generic that computes the second step of the softmax decomposition: res = exp(input ...
Definition: LinalgOps.cpp:2815
static void buildMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
Definition: LinalgOps.cpp:190
static Operation * findPayloadOp(Block *body, bool initFirst=false)
Definition: LinalgOps.cpp:1470
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
Definition: LinalgOps.cpp:162
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
Definition: LinalgOps.cpp:1339
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
Definition: LinalgOps.cpp:2792
static Speculation::Speculatability getGenericSpeculatabilityImpl(LinalgOp linalgOp)
Definition: LinalgOps.cpp:1230
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect >> &effects, LinalgOp linalgOp)
Definition: LinalgOps.cpp:1197
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
Definition: LinalgOps.cpp:2222
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
Definition: LinalgOps.cpp:321
static void buildGenericRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
Definition: LinalgOps.cpp:972
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
Definition: LinalgOps.cpp:314
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)
Return a memref.dim or tensor.dim for the shape of v at dim.
Definition: LinalgOps.cpp:58
static bool isValidResultDimExprs(AffineMap explictMap, AffineMap defaultMap)
Returns true if the result AffineExpr of the explicitMap is same as defaultMap.
Definition: LinalgOps.cpp:3413
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false)
Definition: LinalgOps.cpp:1393
void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
Definition: LinalgOps.cpp:1499
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
Definition: LinalgOps.cpp:351
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs, ArrayRef< StringRef > elidedAttrs={})
Definition: LinalgOps.cpp:358
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
Definition: LinalgOps.cpp:209
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
Base type for affine expression.
Definition: AffineExpr.h:68
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
Definition: AffineExpr.cpp:316
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
AffineMap dropResults(ArrayRef< int64_t > positions) const
Definition: AffineMap.h:299
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:334
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumDims() const
Definition: AffineMap.cpp:394
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:407
unsigned getNumResults() const
Definition: AffineMap.cpp:402
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:411
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:264
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:155
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:148
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:203
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:207
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:427
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:152
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:302
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:404
MLIRContext * getContext() const
Definition: Builders.h:56
Location getUnknownLoc()
Definition: Builders.cpp:27
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:306
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:358
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:772
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void assign(const_iterator inStart, const_iterator inEnd)
Replaces the attributes with new list of attributes.
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
Attribute set(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:49
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:221
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void decreaseIndent()=0
Decrease indentation.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:357
This class helps build Operations.
Definition: Builders.h:216
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:440
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:407
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:470
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:529
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
std::optional< RegisteredOperationName > getRegisteredInfo() const
If this operation is registered, returns the registered information, std::nullopt otherwise.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
result_iterator result_begin()
Definition: Operation.h:408
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
unsigned getNumOperands()
Definition: Operation.h:341
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_iterator result_end()
Definition: Operation.h:409
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator end()
Definition: Region.h:56
Block & emplaceBlock()
Definition: Region.h:46
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:133
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
Definition: Types.cpp:115
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:215
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static Attribute parse(AsmParser &parser, Type type)
Parse the short form [42, 100, -1] without any type prefix.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:1144
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1194
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
Definition: ArithOps.cpp:2598
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
Definition: LinalgOps.cpp:2301
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:105
std::string generateLibraryCallName(Operation *op)
Returns the name mangled library call name to disambiguate between different overloads at the C level...
Definition: LinalgOps.cpp:2342
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
Definition: LinalgOps.cpp:2281
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx ...
Definition: LinalgOps.cpp:2292
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:96
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:44
DynamicAPInt floor(const Fraction &f)
Definition: Fraction.h:77
DynamicAPInt ceil(const Fraction &f)
Definition: Fraction.h:79
DynamicAPInt round(const Fraction &f)
Definition: Fraction.h:136
Fraction abs(const Fraction &f)
Definition: Fraction.h:107
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
uint64_t getM(LevelType lt)
Definition: Enums.h:443
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
Definition: TensorOps.cpp:348
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:66
Include the generated interface declarations.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition: Utils.cpp:239
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:348
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:791
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition: AffineExpr.h:362
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:617
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:426
SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)
Returns a permutation vector that drop the input dims in dropPositions from inputPerm.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Fold transpose with transpose.
Definition: LinalgOps.cpp:1982
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
Definition: LinalgOps.cpp:1985
This pattern canonicalize transpose by swapping the order of broadcast and transpose: transpose(broad...
Definition: LinalgOps.cpp:2007
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
Definition: LinalgOps.cpp:2010
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:373
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
Attribute propertiesAttr
This Attribute is used to opaquely construct the properties of the operation.
Region * addRegion()
Create a region that should be attached to the operation.
Container for result values of tiling.