//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Linalg operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"

#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <optional>

using namespace mlir;
using namespace mlir::linalg;

/// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
                                int64_t dim) {
  auto type = cast<ShapedType>(v.getType());
  if (!type.isDynamicDim(dim))
    return builder.getIndexAttr(type.getDimSize(dim));

  return getAsOpFoldResult(
      TypeSwitch<Type, Value>(v.getType())
          .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
            return builder.create<tensor::DimOp>(loc, v, dim);
          })
          .Case<MemRefType>([&](MemRefType t) -> Value {
            return builder.create<memref::DimOp>(loc, v, dim);
          }));
}

/// Returns a memref.subview or a tensor.extract_slice based on the type of the
/// `source`.
static Operation *getSlice(OpBuilder &b, Location loc, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides) {
  return TypeSwitch<Type, Operation *>(source.getType())
      .Case<RankedTensorType>([&](RankedTensorType t) -> Operation * {
        return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
                                                strides);
      })
      .Case<MemRefType>([&](MemRefType type) -> Operation * {
        return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
                                           strides);
      })
      .Default([&](Type t) -> Operation * { return nullptr; });
}
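
// For example (illustrative IR, not emitted verbatim by this file): slicing a
// ranked tensor produces a tensor.extract_slice, while the same request on a
// memref produces a memref.subview:
//
//   %s = tensor.extract_slice %t[0, 0] [4, 4] [1, 1]
//       : tensor<8x8xf32> to tensor<4x4xf32>
//   %v = memref.subview %m[0, 0] [4, 4] [1, 1]
//       : memref<8x8xf32> to memref<4x4xf32, strided<[8, 1]>>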

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

Value linalg::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
                                int64_t dim) {
  if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
    return b.createOrFold<memref::DimOp>(loc, source, dim);
  if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
    return b.createOrFold<tensor::DimOp>(loc, source, dim);
  llvm_unreachable("Expected MemRefType or TensorType");
}

OpFoldResult linalg::createFoldedDimOp(OpBuilder &b, Location loc, Value source,
                                       int64_t dim) {
  auto shapedType = llvm::cast<ShapedType>(source.getType());
  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
    return createOrFoldDimOp(b, loc, source, dim);
  return b.getIndexAttr(shapedType.getDimSize(dim));
}

//===----------------------------------------------------------------------===//
// Support for named Linalg ops defined in ods-gen.
//===----------------------------------------------------------------------===//

using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &, Block &,
                                                ArrayRef<NamedAttribute>)>;

/// Fills the region of a structured operation using the provided
/// `regionBuilder`. The method is used by both named structured ops created by
/// ods-gen and by manually defined C++ ops. It is called by both builders and
/// parsers and creates a block with arguments corresponding to the elemental
/// types of `inputTypes` and `outputTypes`.
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                                   TypeRange inputTypes, TypeRange outputTypes,
                                   ArrayRef<NamedAttribute> attrs,
                                   RegionBuilderFn regionBuilder) {
  SmallVector<Type, 8> argTypes;
  SmallVector<Location, 8> argLocs;
  for (auto containers : {inputTypes, outputTypes}) {
    for (auto t : containers) {
      argTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);

      // TODO: Pass in a proper location here.
      argLocs.push_back(opBuilder.getUnknownLoc());
    }
  }

  // RAII.
  OpBuilder::InsertionGuard guard(opBuilder);
  Block *body =
      opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);

  opBuilder.setInsertionPointToStart(body);
  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
  regionBuilder(b, *body, attrs);

  // indexing_maps is an auto-generated method.

  // iterator_types is an auto-generated method.
}
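
// For instance (illustrative): given inputs tensor<4x8xf32> and
// tensor<8x16xf32> and an output tensor<4x16xf32>, the block created here has
// three scalar arguments (f32, f32, f32), one per operand's elemental type,
// which the regionBuilder then wires into the payload, e.g. for a matmul:
//
//   ^bb0(%in: f32, %in_0: f32, %out: f32):
//     %0 = arith.mulf %in, %in_0 : f32
//     %1 = arith.addf %out, %0 : f32
//     linalg.yield %1 : f32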

/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
/// The result types are derived automatically if `resultTensorTypes` is none.
/// The body of the operation is filled using `regionBuilder`. All ods-gen
/// created structured operations use the method to implement their builders.
static void buildStructuredOp(OpBuilder &b, OperationState &state,
                              std::optional<TypeRange> resultTensorTypes,
                              ValueRange inputs, ValueRange outputs,
                              ArrayRef<NamedAttribute> attributes,
                              RegionBuilderFn regionBuilder) {
  // Derive the result types if needed.
  SmallVector<Type> derivedResultTypes =
      resultTensorTypes.value_or(TypeRange());
  if (!resultTensorTypes)
    copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
            llvm::IsaPred<RankedTensorType>);

  state.addOperands(inputs);
  state.addOperands(outputs);
  state.addTypes(derivedResultTypes);

  state.addAttributes(attributes);
  state.addAttribute(
      "operandSegmentSizes",
      b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
                              static_cast<int32_t>(outputs.size())}));

  // Create and fill the region of the structured operation.
  Region &region = *state.addRegion();
  fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
                         state.attributes.getAttrs(), regionBuilder);
}

static void buildMatmulOp(OpBuilder &b, OperationState &state,
                          std::optional<TypeRange> resultTensorTypes,
                          ValueRange inputs, ValueRange outputs,
                          ArrayRef<NamedAttribute> attributes,
                          RegionBuilderFn regionBuilder,
                          ArrayRef<AffineMap> indexingMaps) {
  // Initialize the indexing_maps attribute for MatmulOp.
  SmallVector<Attribute, 3> indexingMapsAttrVal;
  indexingMapsAttrVal = llvm::map_to_vector(
      MatmulOp::getDefaultIndexingMaps(b.getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}
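
// The default matmul indexing maps installed above correspond to
// (illustrative):
//
//   affine_map<(d0, d1, d2) -> (d0, d2)>   // A: (M, K)
//   affine_map<(d0, d1, d2) -> (d2, d1)>   // B: (K, N)
//   affine_map<(d0, d1, d2) -> (d0, d1)>   // C: (M, N)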

static void buildBatchMatmulOp(OpBuilder &b, OperationState &state,
                               std::optional<TypeRange> resultTensorTypes,
                               ValueRange inputs, ValueRange outputs,
                               ArrayRef<NamedAttribute> attributes,
                               RegionBuilderFn regionBuilder,
                               ArrayRef<AffineMap> indexingMaps) {
  // Initialize the indexing_maps attribute for BatchMatmulOp.
  SmallVector<Attribute, 4> indexingMapsAttrVal;
  indexingMapsAttrVal =
      llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
        return AffineMapAttr::get(map);
      });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}

/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
                             SmallVectorImpl<Type> &inputTypes,
                             SmallVectorImpl<Type> &outputTypes,
                             bool addOperandSegmentSizes = true) {
  SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands,
      outputsOperands;

  if (succeeded(parser.parseOptionalLess())) {
    if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
      return failure();
  }
  attrsLoc = parser.getCurrentLocation();
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  if (succeeded(parser.parseOptionalKeyword("ins"))) {
    if (parser.parseLParen())
      return failure();

    inputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseOperandList(inputsOperands) ||
        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
      return failure();
  }

  if (succeeded(parser.parseOptionalKeyword("outs"))) {
    outputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
        parser.parseColonTypeList(outputTypes) || parser.parseRParen())
      return failure();
  }

  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands) ||
      parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
                             result.operands))
    return failure();

  if (addOperandSegmentSizes) {
    // This is a bit complex because we're trying to be backward compatible
    // with operation syntax that mixes the inherent attributes and the
    // discardable ones in the same dictionary. If the properties are used, we
    // append the operandSegmentSizes there directly. Otherwise we append it to
    // the discardable attributes dictionary where it is handled by the generic
    // Operation::create(...) method.
    if (result.propertiesAttr) {
      NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
      attrs.append("operandSegmentSizes",
                   parser.getBuilder().getDenseI32ArrayAttr(
                       {static_cast<int32_t>(inputsOperands.size()),
                        static_cast<int32_t>(outputsOperands.size())}));
      result.propertiesAttr = attrs.getDictionary(parser.getContext());
    } else {
      result.addAttribute("operandSegmentSizes",
                          parser.getBuilder().getDenseI32ArrayAttr(
                              {static_cast<int32_t>(inputsOperands.size()),
                               static_cast<int32_t>(outputsOperands.size())}));
    }
  }
  if (!result.propertiesAttr) {
    std::optional<RegisteredOperationName> info =
        result.name.getRegisteredInfo();
    if (info) {
      if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
            return parser.emitError(attrsLoc)
                   << "'" << result.name.getStringRef() << "' op ";
          })))
        return failure();
    }
  }
  return success();
}

static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
                                         ValueRange outputs) {
  if (!inputs.empty())
    p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
  if (!outputs.empty())
    p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
}
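
// The printed form produced here is the familiar structured-op operand clause
// (illustrative):
//
//   linalg.matmul ins(%a, %b : tensor<4x8xf32>, tensor<8x16xf32>)
//                 outs(%c : tensor<4x16xf32>) -> tensor<4x16xf32>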

//===----------------------------------------------------------------------===//
// Specific parsing and printing for named structured ops created by ods-gen.
//===----------------------------------------------------------------------===//

static ParseResult parseNamedStructuredOpRegion(
    OpAsmParser &parser, Region &region, unsigned numRegionArgs,
    TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
    RegionBuilderFn regionBuilder) {
  if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
    return parser.emitError(
        parser.getCurrentLocation(),
        llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
                      "region expects {0} args, got {1}",
                      numRegionArgs, inputTypes.size() + outputTypes.size()));
  }

  OpBuilder opBuilder(parser.getContext());
  fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs,
                         regionBuilder);
  return success();
}

static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
                              SmallVectorImpl<Type> &resultTypes) {
  if (parser.parseOptionalArrowTypeList(resultTypes))
    return failure();
  return success();
}

static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                          OperationState &result,
                                          unsigned numRegionArgs,
                                          RegionBuilderFn regionBuilder) {
  // TODO: Enable when ods-gen supports captures.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // TODO: consider merging results parsing into region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
                                   outputTypes, result.attributes.getAttrs(),
                                   regionBuilder))
    return failure();
  result.addRegion(std::move(region));

  return success();
}

static void printNamedStructuredOpResults(OpAsmPrinter &p,
                                          TypeRange resultTypes) {
  if (resultTypes.empty())
    return;
  p.printOptionalArrowTypeList(resultTypes);
}

static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
                                   ValueRange inputs, ValueRange outputs,
                                   ArrayRef<StringRef> elidedAttrs = {}) {
  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);

  // Printing is shared with generic ops, except for the region and
  // attributes.
  printCommonStructuredOpParts(p, inputs, outputs);

  // Results printing.
  printNamedStructuredOpResults(p, op->getResultTypes());

  // Region is elided.
}

//===----------------------------------------------------------------------===//
// Region builder helper.
// TODO: Move this to a utility library.
// The public methods on this class are referenced directly from generated
// code. Helpers that build the unary, binary, and type conversion functions
// defined by the DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the
// code that uses this class.
//
// Implementations of the math functions must be polymorphic over numeric
// types, internally performing necessary casts. If the function application
// makes no sense, then the only recourse is to assert and return nullptr. This
// can be extended later if it becomes possible to fail construction of the
// region. The invariant should be enforced at a higher level.
//
// TODO: These helpers are currently type polymorphic over the class of integer
// and floating point types, but they will not internally cast within bit
// widths of a class (mixed precision such as i8->i32) or across classes
// (i.e. mixed float and integer). Many such combinations are ambiguous or need
// to be handled with care, and work is being considered to extend the op
// language to make such cases explicit. In the meantime, violating this will
// fail verification, which is deemed acceptable.
//===----------------------------------------------------------------------===//

namespace {

class RegionBuilderHelper {
public:
  RegionBuilderHelper(OpBuilder &builder, Block &block)
      : builder(builder), block(block) {}

  // Build the unary functions defined by OpDSL.
  Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
    if (!isFloatingPoint(arg))
      llvm_unreachable("unsupported non numeric type");
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (unaryFn) {
    case UnaryFn::exp:
      return builder.create<math::ExpOp>(arg.getLoc(), arg);
    case UnaryFn::log:
      return builder.create<math::LogOp>(arg.getLoc(), arg);
    case UnaryFn::abs:
      return builder.create<math::AbsFOp>(arg.getLoc(), arg);
    case UnaryFn::ceil:
      return builder.create<math::CeilOp>(arg.getLoc(), arg);
    case UnaryFn::floor:
      return builder.create<math::FloorOp>(arg.getLoc(), arg);
    case UnaryFn::negf:
      return builder.create<arith::NegFOp>(arg.getLoc(), arg);
    case UnaryFn::reciprocal: {
      Attribute oneAttr = builder.getOneAttr(arg.getType());
      auto one = builder.create<arith::ConstantOp>(arg.getLoc(),
                                                   ::cast<TypedAttr>(oneAttr));
      return builder.create<arith::DivFOp>(arg.getLoc(), one, arg);
    }
    case UnaryFn::round:
      return builder.create<math::RoundOp>(arg.getLoc(), arg);
    case UnaryFn::sqrt:
      return builder.create<math::SqrtOp>(arg.getLoc(), arg);
    case UnaryFn::rsqrt:
      return builder.create<math::RsqrtOp>(arg.getLoc(), arg);
    case UnaryFn::square:
      return builder.create<arith::MulFOp>(arg.getLoc(), arg, arg);
    case UnaryFn::tanh:
      return builder.create<math::TanhOp>(arg.getLoc(), arg);
    case UnaryFn::erf:
      return builder.create<math::ErfOp>(arg.getLoc(), arg);
    }
    llvm_unreachable("unsupported unary function");
  }
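
  // For example (illustrative): buildUnaryFn(UnaryFn::exp, %arg) with
  // %arg : f32 materializes `math.exp %arg : f32` at the end of the block;
  // non-floating-point arguments are rejected up front.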

  // Build the binary functions defined by OpDSL.
  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
    bool allComplex = isComplex(arg0) && isComplex(arg1);
    bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
    bool allInteger = isInteger(arg0) && isInteger(arg1);
    bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
                   arg1.getType().getIntOrFloatBitWidth() == 1;
    if (!allComplex && !allFloatingPoint && !allInteger)
      llvm_unreachable("unsupported non numeric type");
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (binaryFn) {
    case BinaryFn::add:
      if (allComplex)
        return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::sub:
      if (allComplex)
        return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        llvm_unreachable("unsupported operation: sub with bools");
      return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::mul:
      if (allComplex)
        return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::div:
      if (allComplex)
        return builder.create<complex::DivOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::DivFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        llvm_unreachable("unsupported operation: div with bools");
      return builder.create<arith::DivSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::div_unsigned:
      if (!allInteger || allBool)
        llvm_unreachable("unsupported operation: unsigned div not on uint");
      return builder.create<arith::DivUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::powf:
      assert(allFloatingPoint);
      return builder.create<math::PowFOp>(arg0.getLoc(), arg0, arg1);
    }
    llvm_unreachable("unsupported binary function");
  }

  // Build the ternary functions defined by OpDSL.
  Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1,
                       Value arg2) {
    bool headBool =
        isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
    bool tailFloatingPoint =
        isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
    bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (ternaryFn) {
    case TernaryFn::select:
      if (!headBool && !(tailFloatingPoint || tailInteger))
        llvm_unreachable("unsupported non numeric type");
      return builder.create<arith::SelectOp>(arg0.getLoc(), arg0, arg1, arg2);
    }
    llvm_unreachable("unsupported ternary function");
  }

  // Build the type functions defined by OpDSL.
  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
    switch (typeFn) {
    case TypeFn::cast_signed:
      return cast(toType, operand, false);
    case TypeFn::cast_unsigned:
      return cast(toType, operand, true);
    }
    llvm_unreachable("unsupported type conversion function");
  }

  void yieldOutputs(ValueRange values) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    builder.create<YieldOp>(loc, values);
  }

  Value constant(const std::string &value) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    Attribute valueAttr = parseAttribute(value, builder.getContext());
    return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
  }

  Value index(int64_t dim) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
  }

  Type getIntegerType(unsigned width) {
    return IntegerType::get(builder.getContext(), width);
  }

  Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
  Type getFloat64Type() { return Float64Type::get(builder.getContext()); }

private:
  // Generates operations to cast the given operand to a specified type.
  // If the cast cannot be performed, a warning will be issued and the
  // operand returned as-is (which will presumably yield a verification
  // issue downstream).
  Value cast(Type toType, Value operand, bool isUnsignedCast) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    auto loc = operand.getLoc();
    return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
  }

  bool isComplex(Value value) {
    return llvm::isa<ComplexType>(value.getType());
  }
  bool isFloatingPoint(Value value) {
    return llvm::isa<FloatType>(value.getType());
  }
  bool isInteger(Value value) {
    return llvm::isa<IntegerType>(value.getType());
  }

  OpBuilder &builder;
  Block &block;
};

} // namespace

//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//

namespace {

struct EraseSelfCopy : OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.getInputs() != copyOp.getOutputs())
      return rewriter.notifyMatchFailure(copyOp, "not a self copy");
    if (copyOp.hasPureBufferSemantics())
      rewriter.eraseOp(copyOp);
    else
      rewriter.replaceOp(copyOp, copyOp.getInputs());

    return success();
  }
};
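
// A self copy removed by this pattern looks like (illustrative):
//
//   linalg.copy ins(%buf : memref<8xf32>) outs(%buf : memref<8xf32>)
//
// On buffers the op is simply erased; on tensors the result is replaced by
// the input value.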

} // namespace

void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseSelfCopy>(context);
}

//===----------------------------------------------------------------------===//
// FillOp
//===----------------------------------------------------------------------===//

namespace {

/// Fold linalg.fill -> tensor.expand/collapse_shape chain.
///
/// For such op chains, we can create new linalg.fill ops with the result
/// type of the tensor.expand/collapse_shape op.
template <typename TensorReshapeOp>
struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
    if (!oldFill)
      return failure();

    Location loc = oldFill.getLoc();
    TensorReshapeOp newInit;
    if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
      newInit = rewriter.create<TensorReshapeOp>(
          loc, reshapeOp.getResultType(), oldFill.output(),
          reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
          reshapeOp.getStaticOutputShape());
    } else {
      newInit = rewriter.create<TensorReshapeOp>(loc, reshapeOp.getResultType(),
                                                 oldFill.output(),
                                                 reshapeOp.getReassociation());
    }
    rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
                                        ValueRange{newInit});
    return success();
  }
};

/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
/// filling value are the same.
struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = padOp.getConstantPaddingValue();
    if (!padValue || fillOp.value() != padValue)
      return failure();

    ReifiedRankedShapedTypeDims reifiedShape;
    if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
      return rewriter.notifyMatchFailure(
          padOp, "failed to reify tensor.pad op result shape");

    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        padOp.getLoc(), reifiedShape.front(),
        padOp.getResultType().getElementType());
    Value replacement =
        rewriter
            .create<FillOp>(fillOp.getLoc(), ValueRange{padValue},
                            ValueRange{emptyTensor})
            .getResult(0);
    if (replacement.getType() != padOp.getResultType()) {
      replacement = rewriter.create<tensor::CastOp>(
          fillOp.getLoc(), padOp.getResultType(), replacement);
    }
    rewriter.replaceOp(padOp, replacement);
    return success();
  }
};
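
// Illustrative rewrite performed by FoldFillWithPad: padding a tensor that was
// filled with %cst by the same %cst is just a fill of the padded shape, so
//
//   %f = linalg.fill ins(%cst : f32) outs(%t : tensor<4x4xf32>)
//   %p = tensor.pad %f low[1, 1] high[1, 1] { ... tensor.yield %cst ... }
//
// becomes
//
//   %e = tensor.empty() : tensor<6x6xf32>
//   %r = linalg.fill ins(%cst : f32) outs(%e : tensor<6x6xf32>)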

/// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
/// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
/// filling value are the same.
struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!srcPadOp)
      return failure();

    if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
      return failure();

    // Walk back the tensor.insert_slice chain and find the first destination
    // value at the start of the chain.
    Value firstDest = insertOp.getDest();
    while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
      if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
        return failure();

      // Make sure the ranges of values accessed are disjoint. Without this, we
      // cannot fold tensor.pad away.
      bool disjoint = false;
      for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
        // If the dimension has a dynamic offset/size, we cannot guarantee
        // disjointness. So just skip it.
        if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
            insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
            prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
          continue;

        // Get the range start and end, inclusive for both.
        int64_t prevStart = prevOp.getStaticOffset(i);
        int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
                                          prevOp.getStaticStride(i);
        int64_t nextStart = insertOp.getStaticOffset(i);
        int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
                                          insertOp.getStaticStride(i);
        if (prevEnd < nextStart || nextEnd < prevStart) {
          disjoint = true;
          break;
        }
      }

      if (!disjoint)
        break;
      firstDest = prevOp.getDest();
    }

    // Check whether the first destination is a fill op. For overlapped cases,
    // this also cannot be true.
    auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
    if (!dstFillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = srcPadOp.getConstantPaddingValue();
    if (!padValue || dstFillOp.value() != padValue)
      return failure();

    SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
    SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();

    Location loc = insertOp.getLoc();
    MLIRContext *context = getContext();

    AffineExpr sym0, sym1;
    bindSymbols(context, sym0, sym1);
    auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);

    // Calculate the new offsets for the insert. They should be the old offsets
    // plus the low padding sizes.
    SmallVector<OpFoldResult, 4> newOffsets;
    for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
      newOffsets.push_back(affine::makeComposedFoldedAffineApply(
          rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
    }

    RankedTensorType srcPadType = srcPadOp.getSourceType();
    SmallVector<OpFoldResult, 4> newSizes;
    for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
      if (srcPadType.isDynamicDim(i)) {
        newSizes.push_back(
            rewriter.create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
                .getResult());
      } else {
        newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
      }
    }

    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
        newSizes, insertOp.getMixedStrides());
    return success();
  }
};
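
// Illustrative: inserting a padded-and-filled tensor into a destination that
// was itself filled with the same constant lets the pad disappear; the
// original source is inserted directly at offsets shifted by the low padding:
//
//   tensor.insert_slice(tensor.pad(%src), linalg.fill(%cst))
//     -> tensor.insert_slice(%src, linalg.fill(%cst))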

/// Fold tensor.extract(linalg.fill(<input>)) into <input>.
struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
public:
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // See if the tensor input of the tensor.extract op is the result of a
    // linalg.fill op.
    auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // Get the scalar input operand of the linalg.fill op.
    Value extractedScalar = fillOp.getInputs()[0];

    // Replace the tensor.extract op with the scalar value used to fill the
    // tensor.
    rewriter.replaceOp(extractOp, extractedScalar);
    return success();
  }
};

/// Folds pack(fill) into a single fill op if
/// 1. The pack op does not have padding value, or
/// 2. The filled value and padding value are the same.
static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
                                                linalg::PackOp packOp) {
  auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
  if (!fillOp)
    return failure();

  if (auto paddingValue = packOp.getPaddingValue())
    if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
      return failure();

  Value packOpDest = packOp.getDest();
  if (!packOpDest.hasOneUse())
    return failure();

  return rewriter.create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(),
                                         packOp.getDest());
}
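
// Illustrative: packing a filled tensor is itself just a fill of the packed
// layout (names and shapes are hypothetical), so
//
//   linalg.pack (linalg.fill ins(%cst) outs(%src)) ... into %dest
//
// folds to
//
//   linalg.fill ins(%cst : f32) outs(%dest : tensor<2x4x4x4xf32>)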

/// Wrapper pattern that applies the foldFillPackIntoFillOp method.
struct FoldFillWithPack : public OpRewritePattern<linalg::PackOp> {
public:
  FoldFillWithPack(MLIRContext *context)
      : OpRewritePattern<linalg::PackOp>(context) {}

  LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
    if (failed(fillOp))
      return failure();
    rewriter.replaceOp(packOp, fillOp.value().result());
    return success();
  }
};

/// Fold fill with copy.
struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
  using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
                                          fillOp.getInputs(),
                                          copyOp.getOutputs());
      return success();
    }
    if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
                                                  fillOp.getOutputs());
      return success();
    }
    return failure();
  }
};

/// Fold fill with transpose.
struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(
          transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
          transposeOp.getDpsInitOperand(0)->get());
      return success();
    }
    return failure();
  }
};

/// Fold a concat with all elements being fills of the same value
/// into a fill of the concat result shape.
struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
  using OpRewritePattern<tensor::ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    auto concatOperands = concatOp.getInputs();
    if (concatOperands.empty()) {
      return failure();
    }

    auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
    if (!firstFillOp) {
      return failure();
    }
    // Prefetch the fill value.
    OpFoldResult firstFillVal =
        getAsOpFoldResult(firstFillOp.getDpsInputOperand(0)->get());
    // Collect all the outs values for the fill operations.
    SmallVector<Value> allOuts;
    allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());

    auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
      auto fillOp = v.getDefiningOp<linalg::FillOp>();
      if (!fillOp) {
        return false;
      }

      OpFoldResult fillVal =
          getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get());
      if (fillVal != firstFillVal)
        return false;

      allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
      return true;
    };
    if (!llvm::all_of(concatOperands.drop_front(),
                      isDefinedByCompatibleFillOp)) {
      return rewriter.notifyMatchFailure(
          concatOp, "not all operands are defined by a compatible fill op");
    }

    Value outsConcat = rewriter.create<tensor::ConcatOp>(
        concatOp.getLoc(), concatOp.getDim(), allOuts);
    rewriter.replaceOpWithNewOp<linalg::FillOp>(
        concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
    return success();
  }
};
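
// Illustrative: a concat whose operands are all fills of the same value, e.g.
//
//   tensor.concat dim(0) (linalg.fill(%cst, %a), linalg.fill(%cst, %b))
//
// is rewritten to concat the outs operands first and fill the result once:
//
//   %o = tensor.concat dim(0) %a, %b
//   %r = linalg.fill ins(%cst) outs(%o)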

} // namespace

void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
              FoldFillWithPack, FoldFillWithPad,
              FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
              FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
              FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
}

//===----------------------------------------------------------------------===//
// GenericOp
//===----------------------------------------------------------------------===//

static void buildGenericRegion(
    OpBuilder &builder, Location loc, Region &region, ValueRange inputs,
    ValueRange outputs,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
  SmallVector<Type, 4> blockArgTypes;
  SmallVector<Location, 4> blockArgLocs;
  for (ValueRange container : {inputs, outputs}) {
    for (Value v : container) {
      Type t = v.getType();
      blockArgTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
      blockArgLocs.push_back(v.getLoc());
    }
  }

  OpBuilder::InsertionGuard guard(builder);
  Block *bodyBlock =
      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
  bodyBuild(builder, loc, bodyBlock->getArguments());
}

void GenericOp::getAsmBlockArgumentNames(Region &region,
                                         OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "out");
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
    ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall);
  result.addAttributes(attributes);
  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, outputs, bodyBuild);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs,
        builder.getAffineMapArrayAttr(indexingMaps),
        builder.getArrayAttr(llvm::to_vector(llvm::map_range(
            iteratorTypes,
            [&](utils::IteratorType iter) -> mlir::Attribute {
              return IteratorTypeAttr::get(builder.getContext(), iter);
            }))),
        doc.empty() ? StringAttr() : builder.getStringAttr(doc),
        libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
        bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::print(OpAsmPrinter &p) {
  p << " ";

  // Print extra attributes.
  auto genericAttrNames = linalgTraitAttrNames();

  llvm::StringSet<> genericAttrNamesSet;
  genericAttrNamesSet.insert_range(genericAttrNames);
  SmallVector<NamedAttribute, 8> genericAttrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
      // Convert IteratorType enums into the string representation. This is
      // needed, because tests still use the old format when 'iterator_types'
      // attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
      SmallVector<Attribute> iteratorTypeNames =
          llvm::to_vector(llvm::map_range(
              iteratorTypes, [&](utils::IteratorType t) -> Attribute {
                return StringAttr::get(getContext(), stringifyIteratorType(t));
              }));

      genericAttrs.emplace_back(
          getIteratorTypesAttrName(),
          ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
      genericAttrs.push_back(attr);
    }
  }
  if (!genericAttrs.empty()) {
    auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
    p << genericDictAttr;
  }

  // Printing is shared with named ops, except for the region and attributes.
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());

  genericAttrNames.push_back("operandSegmentSizes");
  genericAttrNamesSet.insert(genericAttrNames.back());

  bool hasExtraAttrs = false;
  for (NamedAttribute n : (*this)->getAttrs()) {
    if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
      break;
  }
  if (hasExtraAttrs) {
    p << " attrs = ";
    p.printOptionalAttrDict((*this)->getAttrs(),
                            /*elidedAttrs=*/genericAttrNames);
  }

  // Print region.
  if (!getRegion().empty()) {
    p << ' ';
    p.printRegion(getRegion());
  }

  // Print results.
  printNamedStructuredOpResults(p, getResultTensors().getTypes());
}
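
// The printed generic form looks like (illustrative):
//
//   linalg.generic {indexing_maps = [#map0, #map1],
//                   iterator_types = ["parallel"]}
//       ins(%in : tensor<4xf32>) outs(%out : tensor<4xf32>) {
//   ^bb0(%a: f32, %b: f32):
//     linalg.yield %a : f32
//   } -> tensor<4xf32>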

ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
  DictionaryAttr dictAttr;
  // Parse the core linalg traits, which must be parsed into a dictAttr.
  // The name is unimportant as we will overwrite result.attributes.
  // The core linalg traits must contain the information necessary to pass the
  // verifier.
  llvm::SMLoc attributeLocation = parser.getCurrentLocation();
  if (parser.parseAttribute(dictAttr, "_", result.attributes))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert the array of strings into an array of IteratorType enums. This is
  // needed, because tests still use the old format when 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(attributeLocation)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = utils::symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(parser.getCurrentLocation())
             << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

  // Parsing is shared with named ops, except for the region.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Optional attributes may be added.
  if (succeeded(parser.parseOptionalKeyword("attrs")))
    if (failed(parser.parseEqual()) ||
        failed(parser.parseOptionalAttrDict(result.attributes)))
      return failure();

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parser.parseRegion(*region, {}))
    return failure();
  result.addRegion(std::move(region));

  // Generic ops may specify that a subset of their outputs are tensors. Such
  // outputs are specified in the result type.
  // TODO: may need to move output parsing before region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  return success();
}

static void getGenericEffectsImpl(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects,
    LinalgOp linalgOp) {
  for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
    if (!llvm::isa<MemRefType>(operand.getType()))
      continue;
    effects.emplace_back(
        MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0,
        /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get());
  }

  for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
    if (!llvm::isa<MemRefType>(operand.get().getType()))
      continue;
    if (linalgOp.payloadUsesValueFromOperand(&operand)) {
      effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
                           /*effectOnFullRegion=*/true,
                           SideEffects::DefaultResource::get());
    }
    effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
                         /*effectOnFullRegion=*/true,
                         SideEffects::DefaultResource::get());
  }
}

void GenericOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

static Speculation::Speculatability
getGenericSpeculatabilityImpl(LinalgOp linalgOp) {
  // Operands with value semantics are speculatable, while operands with memory
  // semantics are not.
  if (!linalgOp.hasPureTensorSemantics())
    return Speculation::NotSpeculatable;
  // The body of the op can still have speculation in its region.
  return Speculation::Speculatable;
}

Speculation::Speculatability GenericOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}

LogicalResult GenericOp::verify() { return success(); }

namespace {

/// Remove any linalg operation (on tensors) that is just copying
/// the values from inputs to the results. Requirements are:
/// 1) All iterator types are parallel.
/// 2) The body contains just a yield operation with the yielded values being
///    the arguments corresponding to the operands.
template <typename OpTy>
struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy linalgOp,
                                PatternRewriter &rewriter) const override {
    // All indexing maps must be equal. It follows that they are permutations.
    if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
      return failure();

    // Check that the body of the linalg operation is just a linalg.yield
    // operation.
    Block &body = linalgOp->getRegion(0).front();
    if (!llvm::hasSingleElement(body))
      return failure();
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return failure();

    // In the buffer case, we need to check exact buffer equality.
    if (linalgOp.hasPureBufferSemantics()) {
      if (linalgOp.getNumDpsInputs() == 1 && linalgOp.getNumDpsInits() == 1 &&
          linalgOp.getDpsInputOperand(0)->get() ==
              linalgOp.getDpsInitOperand(0)->get()) {
        rewriter.eraseOp(linalgOp);
        return success();
      }
      return failure();
    }

    // Mixed semantics is not supported yet.
    if (!linalgOp.hasPureTensorSemantics())
      return failure();

    // Get the argument number of the returned values. That is the operand
    // number to use for replacing uses of this operation.
    SmallVector<Value> returnedArgs;
    for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
      auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
      if (!yieldArg || yieldArg.getOwner() != &body)
        return failure();
      unsigned argumentNumber = yieldArg.getArgNumber();
      Value returnedArg = linalgOp->getOperand(argumentNumber);
      Type resultType = linalgOp->getResult(yieldVal.index()).getType();
      // The input can have a different type than the result, e.g. a dynamic
      // input dimension can be turned into a static output dimension.
      Type returnType = returnedArg.getType();
      if (returnType != resultType) {
        // Distinguish between sparse conversion or dense tensor casting.
        // TODO: unify the two ops?
        if (sparse_tensor::getSparseTensorEncoding(returnType) ||
            sparse_tensor::getSparseTensorEncoding(resultType))
          returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        else {
          if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
                                                 resultType))
            return failure();
          returnedArg = rewriter.create<tensor::CastOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        }
      }
      returnedArgs.push_back(returnedArg);
    }

    if (returnedArgs.size() != linalgOp->getNumResults())
      return failure();
    rewriter.replaceOp(linalgOp, returnedArgs);
    return success();
  }
};
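
// Illustrative identity op removed by this pattern: a generic that merely
// yields its input argument is replaced by the input tensor itself.
//
//   %r = linalg.generic {indexing_maps = [#id, #id],
//                        iterator_types = ["parallel"]}
//       ins(%in : tensor<4xf32>) outs(%out : tensor<4xf32>) {
//   ^bb0(%a: f32, %b: f32):
//     linalg.yield %a : f32
//   } -> tensor<4xf32>
//   // ==> all uses of %r are replaced with %in.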

} // namespace

void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<EraseIdentityLinalgOp<GenericOp>>(context);
}

LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// MapOp
//===----------------------------------------------------------------------===//

static ParseResult parseDstStyleOp(
    OpAsmParser &parser, OperationState &result,
    function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
        nullptr) {
  // Parse `ins` and `outs`.
  SmallVector<Type, 4> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
                                   /*addOperandSegmentSizes=*/false))
    return failure();

  // Add result types.
  for (Type outputType : outputTypes) {
    if (llvm::isa<RankedTensorType>(outputType))
      result.addTypes(outputType);
  }

  // Parse required attributes.
  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}

void MapOp::getAsmBlockArgumentNames(Region &region,
                                     OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
}

void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "mapped");
}

void MapOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, init);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, /*outputs=*/{}, bodyBuild);
}

static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
                                 const OperationName &payloadOpName,
                                 const NamedAttrList &payloadOpAttrs,
                                 ArrayRef<Value> operands,
                                 bool initFirst = false) {
  OpBuilder b(parser.getContext());
  Region *body = result.addRegion();
  Block &block = body->emplaceBlock();
  b.setInsertionPointToStart(&block);
  for (auto &operand : operands) {
    block.addArgument(
        llvm::cast<ShapedType>(operand.getType()).getElementType(),
        b.getUnknownLoc());
  }
  SmallVector<Value> payloadOpOperands;
  // If the initFirst flag is enabled, we consider init as the first position
  // of the payload operands.
  if (initFirst) {
    payloadOpOperands.push_back(block.getArguments().back());
    for (const auto &arg : block.getArguments().drop_back())
      payloadOpOperands.push_back(arg);
  } else {
    payloadOpOperands = {block.getArguments().begin(),
                         block.getArguments().end()};
  }

  Operation *payloadOp = b.create(
      result.location, b.getStringAttr(payloadOpName.getStringRef()),
      payloadOpOperands,
      TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
                    .getElementType()},
      payloadOpAttrs);
  b.create<YieldOp>(result.location, payloadOp->getResults());
}

ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
  std::optional<OperationName> payloadOpName;
  NamedAttrList payloadOpAttrs;
  if (succeeded(parser.parseOptionalLBrace())) {
    FailureOr<OperationName> operationName = parser.parseCustomOperationName();
    if (failed(operationName))
      return failure();
    if (parser.parseOptionalAttrDict(payloadOpAttrs))
      return failure();
    payloadOpName = operationName.value();
    if (parser.parseRBrace())
      return failure();
  }

  if (parseDstStyleOp(parser, result))
    return failure();

  if (payloadOpName.has_value()) {
    if (!result.operands.empty())
      addBodyWithPayloadOp(parser, result, payloadOpName.value(),
                           payloadOpAttrs,
                           ArrayRef(result.operands).drop_back());
    else
      result.addRegion();
  } else {
    SmallVector<OpAsmParser::Argument> regionArgs;
    if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
                                 /*allowType=*/true, /*allowAttrs=*/true)) {
      return failure();
    }
    Region *body = result.addRegion();
    if (parser.parseRegion(*body, regionArgs))
      return failure();
  }
  return success();
}

// Retrieve the operation from the body, if it is the only one (except
// yield) and if it gets the same number of arguments as the body does.
// If the initFirst flag is enabled, we check that init takes the first
// position in the operands of the payload.
static Operation *findPayloadOp(Block *body, bool initFirst = false) {
  if (body->getOperations().size() != 2)
    return nullptr;
  Operation &payload = body->getOperations().front();
  assert(isa<YieldOp>(body->getOperations().back()));

  if (payload.getNumOperands() == 0 ||
      payload.getNumOperands() != body->getNumArguments())
    return nullptr;
  if (initFirst) {
    // Check init.
    if (payload.getOperands().back() != body->getArgument(0))
      return nullptr;
    // Check the rest.
    for (const auto &[operand, bbArg] :
         llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
      if (bbArg != operand)
        return nullptr;
    }
  } else {
    for (const auto &[operand, bbArg] :
         llvm::zip(payload.getOperands(), body->getArguments())) {
      if (bbArg != operand)
        return nullptr;
    }
  }
  return &payload;
}
1514 
1515 void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
1516  SmallVector<StringRef> elidedAttrs;
1517  std::string attrToElide;
1518  p << " { " << payloadOp->getName().getStringRef();
1519  for (const auto &attr : payloadOp->getAttrs()) {
1520  auto fastAttr =
1521  llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1522  if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1523  attrToElide = attr.getName().str();
1524  elidedAttrs.push_back(attrToElide);
1525  break;
1526  }
1527  }
1528  p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
1529  p << " }";
1530 }
1531 
1532 void MapOp::print(OpAsmPrinter &p) {
1533  Block *mapper = getBody();
1534  Operation *payloadOp = findPayloadOp(mapper);
1535  if (payloadOp) {
1536  printShortForm(p, payloadOp);
1537  }
1538 
1539  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1540  p.printOptionalAttrDict((*this)->getAttrs());
1541 
1542  if (!payloadOp) {
1543  // Print region if the payload op was not detected.
1544  p.increaseIndent();
1545  p.printNewline();
1546  p << "(";
1547  llvm::interleaveComma(mapper->getArguments(), p,
1548  [&](auto arg) { p.printRegionArgument(arg); });
1549  p << ") ";
1550 
1551  p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
1552  p.decreaseIndent();
1553  }
1554 }
1555 
1556 LogicalResult MapOp::verify() {
1557  auto *bodyBlock = getBody();
1558  auto blockArgs = bodyBlock->getArguments();
1559 
1560  // Check that the number of `inputs` matches the arity of the `mapper` region.
1561  if (getInputs().size() != blockArgs.size())
1562  return emitOpError() << "expects number of operands to match the arity of "
1563  "mapper, but got: "
1564  << getInputs().size() << " and " << blockArgs.size();
1565 
1566  // The parameters of mapper should all match the element type of inputs.
1567  for (const auto &[bbArgType, inputArg] :
1568  llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
1569  auto inputElemType =
1570  llvm::cast<ShapedType>(inputArg.getType()).getElementType();
1571  if (bbArgType != inputElemType) {
1572  return emitOpError() << "expected element type of input " << inputElemType
1573  << " to match bbArg type " << bbArgType;
1574  }
1575  }
1576 
1577  // The shape of each input must match the shape of the output.
1578  auto outputShape = getInit().getType().getShape();
1579  for (Type inputArgType : TypeRange{getInputs()}) {
1580  auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
1581  if (inputElemShape != outputShape) {
1582  return emitOpError() << "expected shape of input (" << inputElemShape
1583  << ") to match shape of output (" << outputShape
1584  << ")";
1585  }
1586  }
1587 
1588  return success();
1589 }
1590 
1591 SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
1592  int64_t rank = getInit().getType().getRank();
1593  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1594 }
1595 
1596 ArrayAttr MapOp::getIndexingMaps() {
1597  Builder builder(getContext());
1598  int64_t rank = getInit().getType().getRank();
1599  int64_t numIndexingMaps = getOperands().size();
1600  return builder.getAffineMapArrayAttr(SmallVector<AffineMap>(
1601  numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
1602 }
1603 
1604 void MapOp::getEffects(
1605  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1606  &effects) {
1607  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1608 }
1609 
1610 Speculation::Speculatability MapOp::getSpeculatability() {
1611  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1612 }
1613 
1614 //===----------------------------------------------------------------------===//
1615 // ReduceOp
1616 //===----------------------------------------------------------------------===//
1617 
1618 void ReduceOp::getAsmBlockArgumentNames(Region &region,
1619  OpAsmSetValueNameFn setNameFn) {
1620  for (Value v : getRegionInputArgs())
1621  setNameFn(v, "in");
1622  for (Value v : getRegionOutputArgs())
1623  setNameFn(v, "init");
1624 }
1625 
1626 void ReduceOp::getAsmResultNames(
1627  function_ref<void(Value, StringRef)> setNameFn) {
1628  if (!getResults().empty())
1629  setNameFn(getResults().front(), "reduced");
1630 }
1631 
1632 void ReduceOp::build(
1633  OpBuilder &builder, OperationState &result, ValueRange inputs,
1634  ValueRange inits, ArrayRef<int64_t> dimensions,
1635  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1636  ArrayRef<NamedAttribute> attributes) {
1637  build(builder, result, TypeRange{}, inputs, inits, dimensions);
1638  result.addAttributes(attributes);
1639 
1640  // Add output types for `RankedTensorType` output arguments.
1641  for (Value init : inits) {
1642  Type initType = init.getType();
1643  if (llvm::isa<RankedTensorType>(initType))
1644  result.addTypes(initType);
1645  }
1646 
1647  if (bodyBuild)
1648  buildGenericRegion(builder, result.location, *result.regions.front(),
1649  inputs, inits, bodyBuild);
1650 }
1651 
1652 SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
1653  int64_t inputRank =
1654  llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1655  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
1656  utils::IteratorType::parallel);
1657  for (int64_t reductionDim : getDimensions())
1658  iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1659  return iteratorTypes;
1660 }
1661 
1662 ArrayAttr ReduceOp::getIndexingMaps() {
1663  int64_t inputRank =
1664  llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1665  SmallVector<AffineMap> affineMaps(
1666  getNumDpsInputs(),
1667  AffineMap::getMultiDimIdentityMap(inputRank, getContext()));
1668  AffineMap resultMap =
1669  AffineMap::getMultiDimIdentityMap(inputRank, getContext())
1670  .dropResults(getDimensions());
1671  for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
1672  affineMaps.push_back(resultMap);
1673  return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
1674 }
1675 
1676 void ReduceOp::getEffects(
1677  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1678  &effects) {
1679  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1680 }
1681 
1682 Speculation::Speculatability ReduceOp::getSpeculatability() {
1683  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1684 }
1685 
1686 static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
1687  NamedAttrList &attributes,
1688  StringRef attributeName) {
1689  if (parser.parseKeyword(attributeName) || parser.parseEqual())
1690  return failure();
1691 
1692  attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{}));
1693  return success();
1694 }
1695 
1696 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
1697  std::optional<OperationName> payloadOpName;
1698  NamedAttrList payloadOpAttrs;
1699  if (succeeded(parser.parseOptionalLBrace())) {
1700  FailureOr<OperationName> operationName = parser.parseCustomOperationName();
1701  if (failed(operationName))
1702  return failure();
1703  if (parser.parseOptionalAttrDict(payloadOpAttrs))
1704  return failure();
1705  payloadOpName = operationName.value();
1706  if (parser.parseRBrace())
1707  return failure();
1708  }
1709 
1710  if (parseDstStyleOp(
1711  parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1712  return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
1713  }))
1714  return failure();
1715 
1716  if (payloadOpName.has_value()) {
1717  addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
1718  ArrayRef(result.operands), /*initFirst=*/true);
1719  } else {
1720  SmallVector<OpAsmParser::Argument> regionArgs;
1721  if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
1722  /*allowType=*/true, /*allowAttrs=*/true)) {
1723  return failure();
1724  }
1725 
1726  Region *body = result.addRegion();
1727  if (parser.parseRegion(*body, regionArgs))
1728  return failure();
1729  }
1730 
1731  return success();
1732 }
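
// For reference, a minimal sketch of the short form this parser accepts
// (hand-written illustration). Note that for reductions the init block
// argument is passed first to the payload op (initFirst = true above):
//
//   %reduced = linalg.reduce { arith.addf }
//                ins(%input : tensor<16x32x64xf32>)
//                outs(%init : tensor<16x64xf32>)
//                dimensions = [1]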
1733 
1734 static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
1735  ArrayRef<int64_t> attributeValue) {
1736  p << ' ' << attributeName << " = [" << attributeValue << "] ";
1737 }
1738 
1739 void ReduceOp::print(OpAsmPrinter &p) {
1740  Block *mapper = getBody();
1741  Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
1742  if (payloadOp) {
1743  printShortForm(p, payloadOp);
1744  }
1745 
1746  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1747  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
1748  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
1749  if (!payloadOp) {
1750  // Print region if the payload op was not detected.
1751  p.increaseIndent();
1752  p.printNewline();
1753  p << "(";
1754  llvm::interleaveComma(mapper->getArguments(), p,
1755  [&](auto arg) { p.printRegionArgument(arg); });
1756  p << ") ";
1757 
1758  p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
1759  p.decreaseIndent();
1760  }
1761 }
1762 
1763 LogicalResult ReduceOp::verify() {
1764  ArrayRef<int64_t> dimensionsRef = getDimensions();
1765 
1766  for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
1767  if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
1768  llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
1769  return emitOpError() << "expects all inputs to have the same shapes. "
1770  "Shape at input-index "
1771  << i
1772  << " is not equal to the shape at input-index 0.";
1773  }
1774  }
1775  for (int64_t i = 1; i < getNumDpsInits(); ++i) {
1776  if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
1777  llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
1778  return emitOpError() << "expects all outputs to have the same shapes. "
1779  "Shape at output-index "
1780  << i
1781  << " is not equal to the shape at output-index 0.";
1782  }
1783  }
1784  auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
1785  auto initType = llvm::cast<ShapedType>(getInits()[0].getType());
1786 
1787  DenseSet<int64_t> dimensionsToReduce;
1788  for (int64_t dimension : dimensionsRef) {
1789  if (dimension < 0 || dimension >= inputType.getRank()) {
1790  return emitOpError()
1791  << "dimensions for reduction should be in the range [0, "
1792  << inputType.getRank() - 1 << "].";
1793  }
1794  dimensionsToReduce.insert(dimension);
1795  }
1796 
1797  auto inputDims = inputType.getShape();
1798  auto initDims = initType.getShape();
1799 
1800  // Input dimensions that will be left after the reduction.
1801  SmallVector<int64_t> reducedInputDims;
1802  for (const auto &en : llvm::enumerate(inputDims)) {
1803  if (!dimensionsToReduce.count(en.index()))
1804  reducedInputDims.push_back(en.value());
1805  }
1806 
1807  if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
1808  return emitOpError() << "number of dimensions after reduction "
1809  << reducedInputDims.size()
1810  << " doesn't match the init rank "
1811  << initType.getRank();
1812  }
1813 
1814  if (reducedInputDims != initDims)
1815  return emitOpError() << "init dimensions [" << initDims
1816  << "] doesn't match input dimensions after reduction ["
1817  << reducedInputDims << "]";
1818 
1819  Block *block = getBody();
1820  if (block->getNumArguments() != this->getNumOperands())
1821  return emitOpError()
1822  << "mismatching number of operands and block arguments";
1823 
1824  // Check that the first block arguments match the element type of the inputs.
1825  for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
1826  Type inputElementType =
1827  llvm::cast<ShapedType>(input.getType()).getElementType();
1828  if (inputElementType != bbArg.getType())
1829  return emitOpError()
1830  << "input element type " << inputElementType
1831  << " does not match corresponding block argument type "
1832  << bbArg.getType();
1833  }
1834 
1835  // Check that the last block arguments match the element type of the outputs.
1836  for (auto [output, bbArg] : llvm::zip(
1837  getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
1838  auto outputElementType =
1839  llvm::cast<ShapedType>(output.getType()).getElementType();
1840  if (outputElementType != bbArg.getType())
1841  return emitOpError()
1842  << "output element type " << outputElementType
1843  << " does not match corresponding block argument type "
1844  << bbArg.getType();
1845  }
1846  return success();
1847 }
1848 
1849 //===----------------------------------------------------------------------===//
1850 // TransposeOp
1851 //===----------------------------------------------------------------------===//
1852 
1853 static void buildIdentityRegion(OpBuilder &builder, Location loc,
1854  Region &region, ValueRange inputs,
1855  ValueRange outputs) {
1856  buildGenericRegion(builder, loc, region, inputs, outputs,
1857  [](OpBuilder &b, Location loc, ValueRange args) {
1858  if (!args.empty())
1859  b.create<linalg::YieldOp>(loc, args[0]);
1860  });
1861 }
1862 
1863 void TransposeOp::build(::mlir::OpBuilder &builder,
1864  ::mlir::OperationState &result, Value input, Value init,
1865  DenseI64ArrayAttr permutation,
1866  ArrayRef<NamedAttribute> attributes) {
1867  result.addOperands(input);
1868  result.addOperands(init);
1869  result.addAttribute(getPermutationAttrName(result.name), permutation);
1870  result.addAttributes(attributes);
1871 
1872  // Add output types for `RankedTensorType` output arguments.
1873  Type initType = init.getType();
1874  if (llvm::isa<RankedTensorType>(initType))
1875  result.addTypes(initType);
1876 
1877  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
1878  init);
1879 }
1880 
1881 void TransposeOp::build(::mlir::OpBuilder &builder,
1882  ::mlir::OperationState &result, Value input, Value init,
1883  ArrayRef<int64_t> permutation,
1884  ArrayRef<NamedAttribute> attributes) {
1885  build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
1886  attributes);
1887 }
1888 
1889 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
1890  if (failed(parseDstStyleOp(
1891  parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1892  return parseDenseI64ArrayAttr(parser, attributes, "permutation");
1893  })))
1894  return failure();
1895 
1896  OpBuilder builder(parser.getContext());
1897  buildIdentityRegion(builder, result.location, *result.addRegion(),
1898  /*inputs=*/result.operands,
1899  /*outputs=*/{});
1900  return success();
1901 }
1902 
1903 void TransposeOp::getAsmResultNames(
1904  function_ref<void(Value, StringRef)> setNameFn) {
1905  if (!getResults().empty())
1906  setNameFn(getResults().front(), "transposed");
1907 }
1908 
1910  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1911  printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
1912  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
1913 }
1914 
1915 LogicalResult TransposeOp::verify() {
1916  ArrayRef<int64_t> permutationRef = getPermutation();
1917 
1918  if (!isPermutationVector(permutationRef))
1919  return emitOpError("permutation is not valid");
1920 
1921  auto inputType = getInput().getType();
1922  auto initType = getInit().getType();
1923 
1924  int64_t rank = inputType.getRank();
1925 
1926  if (rank != initType.getRank())
1927  return emitOpError() << "input rank " << rank
1928  << " does not match init rank " << initType.getRank();
1929 
1930  if (rank != static_cast<int64_t>(permutationRef.size()))
1931  return emitOpError() << "size of permutation " << permutationRef.size()
1932  << " does not match the argument rank " << rank;
1933 
1934  auto inputDims = inputType.getShape();
1935  auto initDims = initType.getShape();
1936 
1937  for (int64_t i = 0; i < rank; ++i) {
1938  int64_t inputDim = inputDims[permutationRef[i]];
1939  int64_t initDim = initDims[i];
1940 
1941  if (inputDim != initDim) {
1942  return emitOpError() << "dim(result, " << i << ") = " << initDim
1943  << " doesn't match dim(input, permutation[" << i
1944  << "]) = " << inputDim;
1945  }
1946  }
1947 
1948  return success();
1949 }
1950 
1951 SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
1952  int64_t rank = getInit().getType().getRank();
1953  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1954 }
1955 
1956 ArrayAttr TransposeOp::getIndexingMaps() {
1957  Builder builder(getContext());
1958  int64_t rank = getInit().getType().getRank();
1959  return builder.getAffineMapArrayAttr(
1960  {inversePermutation(AffineMap::getPermutationMap(
1961  llvm::to_vector_of<unsigned>(getPermutation()), getContext())),
1962  builder.getMultiDimIdentityMap(rank)});
1963 }
1964 
1965 void TransposeOp::getEffects(
1966  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1967  &effects) {
1968  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1969 }
1970 
1971 Speculation::Speculatability TransposeOp::getSpeculatability() {
1972  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1973 }
1974 
1975 LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
1976  SmallVectorImpl<OpFoldResult> &result) {
1977  // Only the tensor type is supported.
1978  if (!isa<TensorType>(getInput().getType()))
1979  return failure();
1980 
1981  // Single dimension transpose.
1982  if (getPermutation().size() == 0) {
1983  result.push_back(getInput());
1984  return success();
1985  }
1986  // Identity permutation.
1987  if (isIdentityPermutation(getPermutation())) {
1988  result.push_back(getInput());
1989  return success();
1990  }
1991 
1992  return failure();
1993 }
1994 
1995 /// Fold transpose with transpose.
1996 struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
1997  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
1998 
1999  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2000  PatternRewriter &rewriter) const override {
2001  auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
2002  if (!defTransposeOp)
2003  return failure();
2004  ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation();
2005  ArrayRef<int64_t> perms = transposeOp.getPermutation();
2006  SmallVector<int64_t> foldedPerms;
2007  foldedPerms.reserve(perms.size());
2008  for (int64_t perm : perms)
2009  foldedPerms.push_back(defPerms[perm]);
2010 
2011  rewriter.replaceOpWithNewOp<TransposeOp>(
2012  transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
2013  foldedPerms);
2014  return success();
2015  }
2016 };
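
// A worked illustration of the composition above (assumed permutations):
//
//   %t0 = linalg.transpose ins(%x)  ... permutation = [1, 2, 0]   // defPerms
//   %t1 = linalg.transpose ins(%t0) ... permutation = [2, 0, 1]   // perms
//
// foldedPerms[i] = defPerms[perms[i]] yields [0, 1, 2], so the pair composes
// to a single transpose with the identity permutation, which the folder
// above then removes entirely.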
2017 
2018 /// This pattern canonicalizes transpose by swapping the order of
2019 /// broadcast and transpose:
2020 /// transpose(broadcast(input)) -> broadcast(transpose(input))
2021 struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
2022  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
2023 
2024  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2025  PatternRewriter &rewriter) const override {
2026  Value input = transposeOp.getInput();
2027  BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
2028  if (!input.hasOneUse() || !broadcastOp)
2029  return failure();
2030 
2031  ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
2032  ArrayRef<int64_t> perms = transposeOp.getPermutation();
2033 
2034  // Get new perms and new dimensions.
2035  SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
2036  SmallVector<int64_t> invertPerm = invertPermutationVector(perms);
2037  SmallVector<int64_t> resultDimensions;
2038  unsigned dimensionSize = dimensions.size();
2039  for (unsigned i = 0; i < dimensionSize; ++i)
2040  resultDimensions.push_back(invertPerm[dimensions[i]]);
2041 
2042  // Create transpose result.
2043  Value broadcastInput = broadcastOp.getInput();
2044  Location loc = transposeOp.getLoc();
2045  MLIRContext *ctx = transposeOp.getContext();
2046  SmallVector<OpFoldResult> dims;
2047  auto broadcastInputTy =
2048  mlir::cast<RankedTensorType>(broadcastInput.getType());
2049  unsigned inputRank = broadcastInputTy.getRank();
2050  for (unsigned i = 0; i < inputRank; ++i) {
2051  if (broadcastInputTy.isDynamicDim(i)) {
2052  dims.push_back(rewriter.create<tensor::DimOp>(loc, broadcastInput, i)
2053  ->getResult(0));
2054  } else {
2055  dims.push_back(IntegerAttr::get(IndexType::get(ctx),
2056  broadcastInputTy.getDimSize(i)));
2057  }
2058  }
2059  SmallVector<OpFoldResult> transposeResultShapes =
2060  applyPermutation(dims, resultPerms);
2061  Value transposeInit = rewriter.create<tensor::EmptyOp>(
2062  transposeOp.getLoc(), transposeResultShapes,
2063  broadcastInputTy.getElementType());
2064 
2065  // Create broadcast(transpose(input)).
2066  Value transposeResult =
2067  rewriter
2068  .create<TransposeOp>(loc, broadcastOp.getInput(), transposeInit,
2069  resultPerms)
2070  ->getResult(0);
2071  rewriter.replaceOpWithNewOp<BroadcastOp>(
2072  transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
2073  return success();
2074  }
2075 };
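
// A minimal illustration of the swap (hand-written, assumed shapes):
//
//   %b = linalg.broadcast ins(%x : tensor<4xf32>)
//          outs(%e0 : tensor<2x4xf32>) dimensions = [0]
//   %t = linalg.transpose ins(%b : tensor<2x4xf32>)
//          outs(%e1 : tensor<4x2xf32>) permutation = [1, 0]
//
// becomes a transpose of the smaller tensor (identity here, rank 1)
// followed by a broadcast along the permuted dimension:
//
//   %t2 = linalg.transpose ins(%x : tensor<4xf32>)
//           outs(%e2 : tensor<4xf32>) permutation = [0]
//   %r  = linalg.broadcast ins(%t2 : tensor<4xf32>)
//           outs(%e1 : tensor<4x2xf32>) dimensions = [1]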
2076 
2077 void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2078  MLIRContext *context) {
2079  results.add<FoldTransposeWithTranspose, SwapTransposeWithBroadcast>(context);
2080 }
2081 
2082 //===----------------------------------------------------------------------===//
2083 // BroadcastOp
2084 //===----------------------------------------------------------------------===//
2085 
2086 void BroadcastOp::build(::mlir::OpBuilder &builder,
2087  ::mlir::OperationState &result, Value input, Value init,
2088  DenseI64ArrayAttr dimensions,
2089  ArrayRef<NamedAttribute> attributes) {
2090  result.addOperands(input);
2091  result.addOperands(init);
2092  result.addAttribute(getDimensionsAttrName(result.name), dimensions);
2093  result.addAttributes(attributes);
2094 
2095  // Add output types for `RankedTensorType` output arguments.
2096  Type initType = init.getType();
2097  if (llvm::isa<RankedTensorType>(initType))
2098  result.addTypes(initType);
2099 
2100  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
2101  init);
2102 }
2103 
2104 void BroadcastOp::build(::mlir::OpBuilder &builder,
2105  ::mlir::OperationState &result, Value input, Value init,
2106  ArrayRef<int64_t> dimensions,
2107  ArrayRef<NamedAttribute> attributes) {
2108  build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
2109  attributes);
2110 }
2111 
2112 ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
2113  if (failed(parseDstStyleOp(
2114  parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2115  return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
2116  })))
2117  return failure();
2118 
2119  OpBuilder builder(parser.getContext());
2120  buildIdentityRegion(builder, result.location, *result.addRegion(),
2121  /*inputs=*/result.operands,
2122  /*outputs=*/{});
2123  return success();
2124 }
2125 
2126 void BroadcastOp::getAsmResultNames(
2127  function_ref<void(Value, StringRef)> setNameFn) {
2128  if (!getResults().empty())
2129  setNameFn(getResults().front(), "broadcasted");
2130 }
2131 
2132 void BroadcastOp::print(OpAsmPrinter &p) {
2133  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
2134  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
2135  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
2136 }
2137 
2138 LogicalResult BroadcastOp::verify() {
2139  ArrayRef<int64_t> dimensionsRef = getDimensions();
2140 
2141  auto inputType = getInput().getType();
2142  auto initType = getInit().getType();
2143 
2144  int64_t inputRank = inputType.getRank();
2145  int64_t initRank = initType.getRank();
2146 
2147  auto inputShape = inputType.getShape();
2148  auto initShape = initType.getShape();
2149 
2150  if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
2151  return emitOpError() << "input rank plus added dimensions does not "
2152  "match init rank. input rank: "
2153  << inputRank
2154  << ", dimensions size: " << dimensionsRef.size()
2155  << ", init rank: " << initRank;
2156 
2157  for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
2158  if (dim < 0 || dim >= initRank)
2159  return emitOpError() << "dimension " << idx
2160  << " is out of range. expected range: [0, "
2161  << initRank - 1 << "], got: " << dim;
2162  }
2163 
2164  // Mapping from input dims to init dims.
2165  SmallVector<int64_t> dimMap;
2166  for (auto dim : llvm::seq<int64_t>(0, initRank)) {
2167  if (!llvm::is_contained(dimensionsRef, dim))
2168  dimMap.push_back(dim);
2169  }
2170 
2171  for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
2172  // This dimension is mapped from the input. Init and input dims should
2173  // match.
2174  if (inputShape[inputDimIdx] != initShape[initDimIdx])
2175  return emitOpError() << "input dim " << inputDimIdx
2176  << " should match init dim " << initDimIdx
2177  << ". input: " << inputShape[inputDimIdx]
2178  << ", init: " << initShape[initDimIdx];
2179  }
2180 
2181  return success();
2182 }
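
// Example of the dimension mapping checked above (hand-written, assumed
// shapes): with dimensions = [1, 3], init dims {1, 3} are the added ones, so
// input dims {0, 1} must match init dims {0, 2}:
//
//   %r = linalg.broadcast ins(%x : tensor<4x16xf32>)
//          outs(%init : tensor<4x8x16x32xf32>) dimensions = [1, 3]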
2183 
2184 SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
2185  int64_t rank = getInit().getType().getRank();
2186  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2187 }
2188 
2189 ArrayAttr BroadcastOp::getIndexingMaps() {
2190  Builder builder(getContext());
2191  int64_t rank = getInit().getType().getRank();
2192  return builder.getAffineMapArrayAttr(
2193  {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
2194  builder.getMultiDimIdentityMap(rank)});
2195 }
2196 
2197 void BroadcastOp::getEffects(
2198  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2199  &effects) {
2200  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
2201 }
2202 
2203 Speculation::Speculatability BroadcastOp::getSpeculatability() {
2204  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
2205 }
2206 
2207 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2208  MLIRContext *context) {
2209  results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
2210 }
2211 
2212 //===----------------------------------------------------------------------===//
2213 // YieldOp
2214 //===----------------------------------------------------------------------===//
2215 
2216 void linalg::YieldOp::print(OpAsmPrinter &p) {
2217  if (getNumOperands() > 0)
2218  p << ' ' << getOperands();
2219  p.printOptionalAttrDict((*this)->getAttrs());
2220  if (getNumOperands() > 0)
2221  p << " : " << getOperandTypes();
2222 }
2223 
2224 ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
2225  SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
2226  SmallVector<Type, 2> types;
2227  SMLoc loc = parser.getCurrentLocation();
2228  return failure(parser.parseOperandList(opInfo) ||
2229  parser.parseOptionalAttrDict(result.attributes) ||
2230  (!opInfo.empty() && parser.parseColonTypeList(types)) ||
2231  parser.resolveOperands(opInfo, types, loc, result.operands));
2232 }
2233 
2234 // Check that the operand count and types match the element types of the
2235 // LinalgOp interface's shaped operands.
2236 static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
2237  if (op.getNumOperands() != linalgOp.getNumDpsInits())
2238  return op.emitOpError("expected number of yield values (")
2239  << op.getNumOperands()
2240  << ") to match the number of inits / outs operands of the enclosing "
2241  << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";
2242 
2243  for (OpOperand &opOperand : op->getOpOperands()) {
2244  OpOperand *outputOperand =
2245  linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
2246  Type elementType = outputOperand->get().getType();
2247  if (isa<MemRefType, RankedTensorType>(elementType))
2248  elementType = getElementTypeOrSelf(outputOperand->get().getType());
2249  if (opOperand.get().getType() != elementType)
2250  return op.emitOpError("type of yield operand ")
2251  << (opOperand.getOperandNumber() + 1) << " ("
2252  << opOperand.get().getType() << ") doesn't match "
2253  << "the element type of the enclosing linalg.generic op ("
2254  << elementType << ")";
2255  }
2256  return success();
2257 }
2258 
2259 LogicalResult linalg::YieldOp::verify() {
2260  auto *parentOp = (*this)->getParentOp();
2261  if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2262  return emitOpError("expected single non-empty parent region");
2263 
2264  if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2265  return verifyYield(*this, linalgOp);
2266 
2267  return emitOpError("expected parent op with LinalgOp interface");
2268 }
2269 
2270 //===----------------------------------------------------------------------===//
2271 // IndexOp
2272 //===----------------------------------------------------------------------===//
2273 
2274 LogicalResult IndexOp::verify() {
2275  auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2276  if (!linalgOp)
2277  return emitOpError("expected parent op with LinalgOp interface");
2278  if (linalgOp.getNumLoops() <= getDim())
2279  return emitOpError("expected dim (")
2280  << getDim() << ") to be lower than the number of loops ("
2281  << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
2282  return success();
2283 }
2284 
2285 OpFoldResult IndexOp::fold(FoldAdaptor adaptor) {
2286  auto linalgOp = dyn_cast_or_null<LinalgOp>((*this)->getParentOp());
2287  // Bail out if `linalg.index` does not have a proper parent yet at this
2288  // point, e.g., when calling `createOrFold` during IR construction in
2289  // `GenericOp::build`.
2290  if (!linalgOp)
2291  return OpFoldResult{};
2292 
2293  // Index of unit dims is always 0.
2294  SmallVector<int64_t, 4> loopBounds = linalgOp.getStaticLoopRanges();
2295  uint64_t dim = getDim();
2296  assert(dim < loopBounds.size() && "Dim is out of bounds");
2297  if (loopBounds[dim] == 1)
2298  return IntegerAttr::get(IndexType::get(getContext()), 0);
2299 
2300  return OpFoldResult{};
2301 }
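
// E.g., inside a linalg op whose static loop range at dimension `dim` is 1
// (say, iterating over a tensor<1x32xf32> at dim 0), the only value
// `linalg.index dim` can produce is 0, so it folds to the index constant 0.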
2302 
2303 /////// Operations corresponding to library calls defined with Tablegen ////////
2304 
2305 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2306 
2307 #define GET_OP_CLASSES
2308 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2309 
2310 #define GET_OP_CLASSES
2311 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2312 #define GET_OP_CLASSES
2313 #include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.cpp.inc"
2314 
2315 AffineMap mlir::linalg::extractOrIdentityMap(std::optional<AffineMap> maybeMap,
2316  unsigned rank,
2317  MLIRContext *context) {
2318  if (maybeMap)
2319  return *maybeMap;
2320  if (rank == 0)
2321  return AffineMap::get(context);
2322  return AffineMap::getMultiDimIdentityMap(rank, context);
2323 }
2324 
2325 SmallVector<AffineExpr, 4>
2326 mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
2327  MLIRContext *context) {
2328  SmallVector<AffineExpr, 4> res;
2329  res.reserve(num);
2330  for (unsigned i = 0; i < num; ++i)
2331  res.push_back(getAffineDimExpr(startIdx++, context));
2332  return res;
2333 }
2334 
2335 SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
2336  ArrayRef<AffineExpr> b) {
2337  auto rangeA = llvm::make_range(a.begin(), a.end());
2338  auto rangeB = llvm::make_range(b.begin(), b.end());
2339  auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2340  return llvm::to_vector<4>(concatRanges);
2341 }
2342 
2343 static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
2344  if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
2345  ss << "view";
2346  for (auto size : memref.getShape())
2347  if (size < 0)
2348  ss << "sx";
2349  else
2350  ss << size << "x";
2351  if (failed(appendMangledType(ss, memref.getElementType())))
2352  return failure();
2353  if (auto as = memref.getMemorySpace()) {
2354  if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
2355  ss << "as" << attr.getInt();
2356  else
2357  return failure();
2358  }
2359  return success();
2360  }
2361  if (auto vec = llvm::dyn_cast<VectorType>(t)) {
2362  ss << "vector";
2363  llvm::interleave(
2364  vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
2365  if (failed(appendMangledType(ss, vec.getElementType())))
2366  return failure();
2367  return success();
2368  }
2369  if (t.isSignlessIntOrIndexOrFloat()) {
2370  ss << t;
2371  return success();
2372  }
2373  return failure();
2374 }
2375 
2376 std::string mlir::linalg::generateLibraryCallName(Operation *op) {
2377  assert(isa<LinalgOp>(op));
2378  std::string name(op->getName().getStringRef().str());
2379  std::string fun = "";
2380  for (NamedAttribute kv : op->getAttrs()) {
2381  if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2382  fun = stringifyEnum(ufa.getValue()).str() + "_";
2383  } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2384  fun = stringifyEnum(bfa.getValue()).str() + "_";
2385  }
2386  }
2387  name.reserve(128);
2388  std::replace(name.begin(), name.end(), '.', '_');
2389  llvm::raw_string_ostream ss(name);
2390  ss << "_" << fun;
2391  for (Type t : op->getOperandTypes()) {
2392  if (failed(appendMangledType(ss, t)))
2393  return std::string();
2394  ss << "_";
2395  }
2396  name.pop_back();
2397  return name;
2398 }
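
// For illustration (assumed example, not taken from a test): a linalg.matmul
// on three memref<?x?xf32> operands would mangle to something like
//
//   linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32
//
// where each dynamic size prints as "s", '.' in the op name becomes '_',
// and operand types are joined with '_'.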
2399 
2400 //===----------------------------------------------------------------------===//
2401 // Canonicalizers and Folders.
2402 //===----------------------------------------------------------------------===//
2403 
2404 namespace {
2405 struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
2406  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2407 
2408  LogicalResult matchAndRewrite(LinalgOp op,
2409  PatternRewriter &rewriter) const override {
2410  for (OpOperand &opOperand : op->getOpOperands()) {
2411  // Linalg "inputs" may be either tensor or memref type.
2412  // tensor<0xelt_type> is a convention that may not always mean
2413  // "0 iterations". Only erase in cases we see memref<...x0x...>.
2414  auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2415  if (!mt)
2416  continue;
2417  if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2418  rewriter.eraseOp(op);
2419  return success();
2420  }
2421  }
2422  return failure();
2423  }
2424 };
2425 
2426 /// Fold LinalgOps with `tensor.cast` consumer if the `tensor.cast` has a
2427 /// result that is more static than the linalg op.
2428 struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
2429  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
2430 
2431  LogicalResult matchAndRewrite(tensor::CastOp castOp,
2432  PatternRewriter &rewriter) const override {
2433  if (!tensor::canFoldIntoProducerOp(castOp))
2434  return failure();
2435 
2436  auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2437  if (!linalgOp)
2438  return failure();
2439 
2440  // The cast can be in a conditionally reachable region, in which case
2441  // folding will generate invalid code. Only conservatively fold ops in the
2442  // same block for now.
2443  if (castOp->getBlock() != linalgOp->getBlock())
2444  return failure();
2445 
2446  OpBuilder::InsertionGuard guard(rewriter);
2447  rewriter.setInsertionPoint(linalgOp);
2448 
2449  Location loc = linalgOp.getLoc();
2450  OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2451  unsigned resultNumber = resultValue.getResultNumber();
2452  auto resultType =
2453  llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2454  // Replace the `outs` for the result with a `tensor.cast`. This cast is now
2455  // going from a more dynamic shape to a less dynamic shape. If the producer
2456  // for this cast, i.e. producer of the out operand, is also an operation
2457  // that folds with tensor.cast consumer (like this pattern), the cast will
2458  // continue to propagate as far up the stack as it can go.
2459  OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2460  Value newOperand =
2461  rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
2462  SmallVector<Value> newOperands = linalgOp.getDpsInputs();
2463  SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
2464  linalgOp.getDpsInits().end());
2465  outputOperands[resultNumber] = newOperand;
2466  newOperands.append(outputOperands.begin(), outputOperands.end());
2467 
2468  SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
2469  linalgOp->result_type_end());
2470  resultTypes[resultNumber] = resultType;
2471  Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2472 
2473  // Create a tensor.cast operation back to the original type.
2474  Value castBack = rewriter.create<tensor::CastOp>(
2475  loc, resultValue.getType(), newOp->getResult(resultNumber));
2476 
2477  SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
2478  results[resultNumber] = castBack;
2479  rewriter.replaceOp(linalgOp, results);
2480  rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
2481  return success();
2482  }
2483 };
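
// A minimal before/after sketch of this fold (hand-written illustration):
//
//   %res  = linalg.generic ... outs(%out : tensor<?x?xf32>) -> tensor<?x?xf32>
//   %cast = tensor.cast %res : tensor<?x?xf32> to tensor<4x8xf32>
//
// becomes
//
//   %out2 = tensor.cast %out : tensor<?x?xf32> to tensor<4x8xf32>
//   %res2 = linalg.generic ... outs(%out2 : tensor<4x8xf32>) -> tensor<4x8xf32>
//   %back = tensor.cast %res2 : tensor<4x8xf32> to tensor<?x?xf32>
//
// where uses of %cast see %res2 directly and remaining uses of %res see
// %back.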
2484 
2485 /// For each operand in `operands`, maps each affine dim expression of its
2486 /// indexing map to the operand's known static dimension size.
2487 static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
2488  llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
2489  for (OpOperand &opOperand : operands) {
2490  if (linalgOp.isScalar(&opOperand))
2491  continue;
2492  Value src = opOperand.get();
2493  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2494  auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2495 
2496  // Get the `sourceShape` of the `sourceType`. If the operand is a result of
2497  // `tensor.cast` operation and source of the cast operation has a static
2498  // shape, then assign it to the `sourceShape`.
2499  auto *parentOp = src.getDefiningOp();
2500  ArrayRef<int64_t> sourceShape = sourceType.getShape();
2501  if (parentOp) {
2502  if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2503  Value castSource = castOp.getSource();
2504  auto castSourceType =
2505  llvm::dyn_cast<RankedTensorType>(castSource.getType());
2506  if (castSourceType && castSourceType.hasStaticShape())
2507  sourceShape = castSourceType.getShape();
2508  }
2509  }
2510 
2511  // If the source shape's dimension has a static shape, map the affine dim
2512  // expression to the known static size.
2513  for (unsigned i = 0; i < sourceShape.size(); i++) {
2514  if (sourceType.isDynamicDim(i))
2515  continue;
2516  if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2517  affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2518  }
2519  }
2520 }
2521 
2522 /// Creates a new operand w.r.t. `opOperand` of `linalgOp` with static sizes
2523 /// mapped in `affineExprToSize`. New operands are appended to `newOperands`
2524 /// and their result types are stored in `resultTypes`. If `opOperand` requires
2525 /// no change, then `changeNeeded` remains false and the same operand is added
2526 /// to the `newOperands` list.
2527 static void createNewOperandWithStaticSizes(
2528  Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
2529  llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
2530  SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
2531  bool &changeNeeded) {
2532  Value src = opOperand->get();
2533  newOperands.push_back(src);
2534  if (linalgOp.isScalar(opOperand))
2535  return;
2536  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2537  Type resultType = sourceType;
2538  if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
2539  resultTypes.push_back(resultType);
2540  return;
2541  }
2542  ArrayRef<int64_t> sourceShape = sourceType.getShape();
2543  AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
2544  SmallVector<int64_t> newShape;
2545  // If the operand is updated with a new shape, `newOperandNeeded` will be
2546  // true.
2547  bool newOperandNeeded = false;
2548  for (unsigned i = 0; i < sourceShape.size(); i++) {
2549  int64_t dimShape = sourceShape[i];
2550  AffineExpr dimExpr = sourceMap.getResult(i);
2551  if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
2552  newShape.push_back(dimShape);
2553  continue;
2554  }
2555  // Dimension has a dynamic shape and corresponding affine dim
2556  // expression is present in the map. So assign the size for the
2557  // given affine dim expression to the dimension.
2558  newShape.push_back(affineExprToSize[dimExpr]);
2559  newOperandNeeded = true;
2560  }
2561  resultType = RankedTensorType::get(newShape, sourceType.getElementType(),
2562  sourceType.getEncoding());
2563  if (newOperandNeeded) {
2564  changeNeeded = true;
2565  // Get the new operand value given its size and element type by
2566  // casting it.
2567  Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
2568  unsigned index = opOperand->getOperandNumber();
2569  newOperands[index] = newOperand;
2570  }
2571  if (linalgOp.isDpsInit(opOperand))
2572  resultTypes.push_back(resultType);
2573 }
2574 
2575 /// Static shapes for the operands can be inferred if any one of the operands
2576 /// has a static shape. This can be done by referring to the affine dim
2577 /// expressions of the operands' indexing maps.
2578 struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
2579  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2580 
2581  LogicalResult matchAndRewrite(LinalgOp linalgOp,
2582  PatternRewriter &rewriter) const override {
2583  if (!linalgOp.hasPureTensorSemantics())
2584  return failure();
2585 
2586  // Maps must be projected permutations.
2587  if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
2588  return !map.isProjectedPermutation();
2589  }))
2590  return failure();
2591 
2592  // Maps affine dim expressions to the static size of that dimension.
2593  llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
2594  Location loc = linalgOp.getLoc();
2595 
2596  // For each of the affine dim expression, check if the size is known. If
2597  // known add that in the map.
2598  populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2599 
2600  SmallVector<Value> newOperands;
2601  SmallVector<Type> resultTypes;
2602 
2603  // `changeNeeded` is `false` if the operands of `linalgOp` require no
2604  // change in their types.
2605  bool changeNeeded = false;
2606  newOperands.reserve(linalgOp->getNumOperands());
2607  resultTypes.reserve(linalgOp.getNumDpsInits());
2608 
2609  // Iterate over all the operands and update the static sizes.
2610  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2611  createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2612  affineExprToSize, linalgOp, newOperands,
2613  resultTypes, changeNeeded);
2614  }
2615 
2616  // If the generic op has all the required static information, no
2617  // canonicalization needed.
2618  if (!changeNeeded)
2619  return failure();
2620 
2621  // Clone op.
2622  Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2623  SmallVector<Value> replacements;
2624  replacements.reserve(newOp->getNumResults());
2625  for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
2626  Value newResult = std::get<1>(it);
2627  Value oldResult = std::get<0>(it);
2628  Type newType = newResult.getType();
2629  Type oldType = oldResult.getType();
2630  replacements.push_back(
2631  (newType != oldType)
2632  ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
2633  : newResult);
2634  }
2635  rewriter.replaceOp(linalgOp, replacements);
2636  return success();
2637  }
2638 };
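
// A minimal sketch of the inference (hand-written illustration): since both
// operands share the same (identity) indexing map, the static sizes of %a
// refine the dynamic operand and result:
//
//   %0 = linalg.map { arith.negf } ins(%a : tensor<4x8xf32>)
//          outs(%b : tensor<?x?xf32>)
//
// becomes
//
//   %b2 = tensor.cast %b : tensor<?x?xf32> to tensor<4x8xf32>
//   %1  = linalg.map { arith.negf } ins(%a : tensor<4x8xf32>)
//           outs(%b2 : tensor<4x8xf32>)
//   %0  = tensor.cast %1 : tensor<4x8xf32> to tensor<?x?xf32>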
2639 
2640 } // namespace
2641 
2642 // Canonicalizers and folders for all named ops are auto-generated in the
2643 // .cpp.inc files.
2644 
2645 //===----------------------------------------------------------------------===//
2646 // SoftmaxOp
2647 //===----------------------------------------------------------------------===//
2648 
2649 LogicalResult SoftmaxOp::verify() {
2650  ShapedType inputType = getInputOperandType();
2651  ShapedType outputType = getOutputOperandType();
2652 
2653  ArrayRef<int64_t> inputShape = inputType.getShape();
2654  ArrayRef<int64_t> outputShape = outputType.getShape();
2655  if (failed(verifyCompatibleShape(inputShape, outputShape)))
2656  return emitOpError("incompatible output shape");
2657 
2658  int64_t inputRank = getInputOperandRank();
2659  int64_t dimension = getDimension();
2660  if ((dimension < 0) || (dimension >= inputRank))
2661  return emitOpError("incorrect dimension specified");
2662 
2663  return success();
2664 }
2665 
2666 SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
2667  int64_t operandRank = getInputOperandRank();
2668  SmallVector<Range> loopBounds(operandRank);
2669  Location loc = getLoc();
2670  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
2671  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
2672  Value source = getInput();
2673  for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
2674  loopBounds[dim].offset = zero;
2675  loopBounds[dim].size = getDimValue(builder, loc, source, dim);
2676  loopBounds[dim].stride = one;
2677  }
2678  return loopBounds;
2679 }
2680 
2681 SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
2682  SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
2683  utils::IteratorType::parallel);
2684  iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2685  return iteratorTypes;
2686 }
2687 
2688 FailureOr<TilingResult>
2689 SoftmaxOp::getTiledImplementation(OpBuilder &builder,
2690  ArrayRef<OpFoldResult> offsets,
2691  ArrayRef<OpFoldResult> sizes) {
2692  int64_t rank = getInputOperandRank();
2693  auto oneAttr = builder.getI64IntegerAttr(1);
2694  SmallVector<OpFoldResult> strides(rank, oneAttr);
2695  SmallVector<Value> tiledOperands;
2696  Operation *inputSlice =
2697  getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
2698  if (!inputSlice) {
2699  return emitOpError("failed to compute input slice");
2700  }
2701  tiledOperands.emplace_back(inputSlice->getResult(0));
2702  Operation *outputSlice =
2703  getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
2704  if (!outputSlice) {
2705  return emitOpError("failed to compute output slice");
2706  }
2707  tiledOperands.emplace_back(outputSlice->getResult(0));
2708 
2709  SmallVector<Type, 4> resultTypes;
2710  if (hasPureTensorSemantics())
2711  resultTypes.push_back(tiledOperands[1].getType());
2712  Operation *tiledOp =
2713  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2714 
2715  return TilingResult{
2716  {tiledOp},
2717  SmallVector<Value>(tiledOp->getResults()),
2718  llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
2719 }
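
// In other words, a tile of softmax is computed by slicing the input and the
// output with the same offsets/sizes (all strides are 1) and cloning the op
// onto the two slices. E.g., offsets = [%o, 0] and sizes = [8, 64] on a
// tensor<16x64xf32> softmax yield a softmax over matching 8x64 slices.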
2720 
2721 LogicalResult SoftmaxOp::getResultTilePosition(
2722  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2723  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2724  SmallVector<OpFoldResult> &resultSizes) {
2725  if (resultNumber == 0) {
2726  resultOffsets.assign(offsets.begin(), offsets.end());
2727  resultSizes.assign(sizes.begin(), sizes.end());
2728  return success();
2729  }
2730  return failure();
2731 }
2732 
2733 // cast(dynamic) -> static.
2734 LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
2735  return memref::foldMemRefCast(*this);
2736 }
2737 
2738 LogicalResult
2739 SoftmaxOp::reifyResultShapes(OpBuilder &b,
2740  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2741  SmallVector<OpFoldResult> shapes;
2742  Location loc = getOperation()->getLoc();
2743  IRRewriter rewriter(b);
2744  auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2745  auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2746  for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2747  if (!outputShapedType.isDynamicDim(dim)) {
2748  // Static dim: Return IntegerAttr.
2749  shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
2750  } else {
2751  // Dynamic dim: Return Value.
2752  OpFoldResult ofr = createOrFoldDimOp(b, loc, getInput(), dim);
2753  shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
2754  }
2755  }
2756  reifiedReturnShapes.emplace_back(std::move(shapes));
2757  return success();
2758 }
2759 
2760 void SoftmaxOp::getEffects(
2761  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2762  &effects) {
2763  for (auto [index, operand] : llvm::enumerate(getDpsInputs())) {
2764  if (!llvm::isa<MemRefType>(operand.getType()))
2765  continue;
2766  effects.emplace_back(MemoryEffects::Read::get(),
2767  &getOperation()->getOpOperand(index), /*stage=*/0,
2768  /*effectOnFullRegion=*/true,
2769  SideEffects::DefaultResource::get());
2770  }
2771 
2772  for (OpOperand &operand : getDpsInitsMutable()) {
2773  if (!llvm::isa<MemRefType>(operand.get().getType()))
2774  continue;
2775  effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
2776  /*effectOnFullRegion=*/true,
2777  SideEffects::DefaultResource::get());
2778  effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
2779  /*effectOnFullRegion=*/true,
2780  SideEffects::DefaultResource::get());
2781  }
2782 }
2783 
2784 // Helper functions for softmax decomposition.
2785 // @{
2786 
2787 // Helper function to produce the iterator types (reduction or parallel) and
2788 // affine maps for the iterators used in the decomposition of softmax.
2789 // This method creates:
2790 // If allParallel == true:
2791 // - iterator type: {parallel, ..., parallel}
2792 // - affine maps:
2793 // -- identity with inputRank dimensions.
2794 // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2795 // where N == inputRank - 1.
2796 //
2797 // If allParallel == false:
2798 // - iterator type: parallel at every dim i != \p dim and
2799 // reduction at \p dim.
2800 // - affine map:
2801 // -- identity with inputRank dimensions.
2802 // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2803 // where N == inputRank - 1.
2804 static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
2805 computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank,
2806  int64_t dim, bool allParallel = false) {
2807  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
2808  utils::IteratorType::parallel);
2809  if (!allParallel)
2810  iteratorTypes[dim] = utils::IteratorType::reduction;
2811  MLIRContext *ctxt = builder.getContext();
2812  auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
2813  SmallVector<AffineExpr, 2> affineExprs;
2814  for (int i = 0; i < inputRank; i++) {
2815  if (i != dim)
2816  affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
2817  }
2818  auto reductionMap =
2819  AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
2820  SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
2821  return std::make_tuple(iteratorTypes, indexingMaps);
2822 }
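
// For instance, with inputRank = 3 and dim = 1 this produces (assuming
// allParallel == false):
//   iterator types: {parallel, reduction, parallel}
//   identityMap:    (d0, d1, d2) -> (d0, d1, d2)
//   reductionMap:   (d0, d1, d2) -> (d0, d2)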
2823 
2824 // Helper function to produce a linalg.generic that computes a reduction on
2825 // dimension \p dim with the operation type \p T.
2826 template <typename T>
2827 static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
2828  int64_t dim) {
2829  auto inputType = cast<ShapedType>(input.getType());
2830  ArrayRef<int64_t> inputShape = inputType.getShape();
2831  int64_t inputRank = inputShape.size();
2832  auto [iteratorTypes, indexingMaps] =
2833  computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
2834  assert(indexingMaps.size() == 2 &&
2835  "We should have two maps: 1 for the input, 1 for the output");
2836  assert(indexingMaps[0].isIdentity() && "input map should be identity");
2837 
2838  auto genericOp = builder.create<linalg::GenericOp>(
2839  loc, output.getType(), input, output, indexingMaps, iteratorTypes,
2840  [&](OpBuilder &b, Location loc, ValueRange args) {
2841  Value result = b.create<T>(loc, args[0], args[1]);
2842  b.create<linalg::YieldOp>(loc, result);
2843  });
2844  return genericOp.getResult(0);
2845 }
2846 
2847 /// Produce a linalg generic that computes the second step of the softmax
2848 /// decomposition: res = exp(input - max), where \p max is the max of \p input
2849 /// on dimension \p dim.
2850 static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
2851  Value max, Value output, int64_t dim) {
2852  auto inputType = cast<ShapedType>(input.getType());
2853  ArrayRef<int64_t> inputShape = inputType.getShape();
2854  int64_t inputRank = inputShape.size();
2855  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
2856  builder, inputRank, dim, /*allParallel=*/true);
2857  assert(indexingMaps.size() == 2 && "We should have one map for each input");
2858  assert(indexingMaps[0].isIdentity() && "input map should be identity");
2859  // Add the affine map for the output argument.
2860  indexingMaps.push_back(indexingMaps[0]);
2861  auto genericOp = builder.create<linalg::GenericOp>(
2862  loc, input.getType(), ValueRange{input, max}, output, indexingMaps,
2863  iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
2864  Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
2865  Value result = b.create<math::ExpOp>(loc, diff);
2866  b.create<linalg::YieldOp>(loc, result);
2867  });
2868  return genericOp.getResult(0);
2869 }
2870 
2871 /// Produce a linalg generic that computes the final step of the softmax
2872 /// decomposition.
2873 /// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
2874 /// yield n / d
2875 /// }
2876 static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
2877  Value denominator, Value output, int64_t dim) {
2878  auto inputType = cast<ShapedType>(numerator.getType());
2879  ArrayRef<int64_t> inputShape = inputType.getShape();
2880  int64_t inputRank = inputShape.size();
2881  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
2882  builder, inputRank, dim, /*allParallel=*/true);
2883  assert(indexingMaps.size() == 2 &&
2884  "We should have one map for each input (2)");
2885  assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
2886  // Add the affine map for the output tensor.
2887  indexingMaps.push_back(indexingMaps[0]);
2888  auto genericOp = builder.create<linalg::GenericOp>(
2889  loc, numerator.getType(), ValueRange{numerator, denominator}, output,
2890  indexingMaps, iteratorTypes,
2891  [&](OpBuilder &b, Location loc, ValueRange args) {
2892  Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
2893  b.create<linalg::YieldOp>(loc, result);
2894  });
2895  return genericOp.getResult(0);
2896 }
2897 // @} End helper functions for softmax decomposition.
2898 
2899 /// Given an N-dimensional tensor x, this method converts
2900 /// softmax(x) to the following sequence of operations:
2901 ///
2902 /// 1. Compute the max of x along dimension d. This results
2903 /// in an N-1 dimensional tensor m.
2904 /// m = max(x, dim = d)
2905 ///
2906 /// 2. Subtract a broadcasted m from x and exponentiate. This results in
2907 /// an N-dimensional tensor z.
2908 /// z = exp(x - m)
2909 ///
2910 /// 3. Compute the sum of z along dimension d. This results in
2911 /// an N-1 dimensional tensor l.
2912 /// l = sum(z, dim = d)
2913 ///
2914 /// 4. Divide z and l. This gives the N-dimensional softmax.
2915 /// softmax = z / l
2916 ///
2917 FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
2918  OpBuilder::InsertionGuard guard(b);
2919  b.setInsertionPoint(*this);
2920  Location loc = getLoc();
2921  Value input = getInput();
2922  ShapedType inputType = getInputOperandType();
2923  Type elementType = inputType.getElementType();
2924  int64_t reductionDim = getDimension();
2925  SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
2926  Value output = getOutput();
2927  dims.erase(dims.begin() + reductionDim);
2928  // Step 1: Compute max along dim.
2929  Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
2930  Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maxnumf,
2931  elementType, b, loc,
2932  /*useOnlyFiniteValue=*/true);
2933  Value neutralForMaxFInit =
2934  b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
2935  .result();
2936  Value max =
2937  reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
2938 
2939  // Step 2: Subtract max from input and exponentiate.
2940  Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);
2941 
2942  // Step 3: Compute sum along dim.
2943  Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
2944  b, loc, /*useOnlyFiniteValue=*/true);
2945  Value zeroInit =
2946  b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result();
2947  Value denominator =
2948  reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
2949 
2950  // Step 4: Compute softmax.
2951  Value result =
2952  buildDivOp(b, loc, numerator, denominator, output, reductionDim);
2953  return SmallVector<Value>{result};
2954 }
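
// In scalar terms, the four steps compute the numerically stable form of
// softmax along dimension d:
//
//   softmax(x)[i] = exp(x[i] - m) / sum_j exp(x[j] - m),  m = max_j x[j]
//
// which is mathematically identical to exp(x[i]) / sum_j exp(x[j]) but
// avoids overflow in exp for large inputs.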
2955 
2956 //===----------------------------------------------------------------------===//
2957 // WinogradFilterTransformOp
2958 //===----------------------------------------------------------------------===//
2959 
2960 LogicalResult WinogradFilterTransformOp::verify() {
2961  auto filterType = cast<ShapedType>(getFilter().getType());
2962  ArrayRef<int64_t> filterShape = filterType.getShape();
2963  int64_t filterH = filterShape[getFilterHDim()];
2964  int64_t filterW = filterShape[getFilterWDim()];
2965  int64_t r = getR();
2966  int64_t m = getM();
2967 
2968  if (filterH != r && filterH != 1)
2969  return emitOpError("expected filter height to be equal to r or 1");
2970  if (filterW != r && filterW != 1)
2971  return emitOpError("expected filter width to be equal to r or 1");
2972  if (filterH == 1 && filterW == 1)
2973  return emitOpError("expected either filter height or width to be equal to r");
2974 
2975  SmallVector<int64_t> expectedOutputShape;
2976  expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
2977  expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
2978  expectedOutputShape.push_back(filterShape[getFilterCDim()]);
2979  expectedOutputShape.push_back(filterShape[getFilterFDim()]);
2980 
2981  auto outputType = cast<ShapedType>(getOutput().getType());
2982  ArrayRef<int64_t> outputShape = outputType.getShape();
2983  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
2984  return emitOpError("unexpected output shape");
2985  }
2986  return success();
2987 }
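
// Shape example (hypothetical sizes, for illustration): with m = 4 and r = 3,
// alpha = m + r - 1 = 6, so a filter of shape (F=64, KH=3, KW=3, C=32) must
// pair with an output of shape (6, 6, 32, 64). A width-only filter
// (KH=1, KW=3) would instead pair with (1, 6, 32, 64).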
2988 
2989 SmallVector<Range>
2990 WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
2991  Location loc = getLoc();
2992  IntegerAttr zeroAttr = builder.getIndexAttr(0);
2993  IntegerAttr oneAttr = builder.getIndexAttr(1);
2994  Value filter = getFilter();
2995  int64_t filterRank = getFilterOperandRank();
2996  SmallVector<Range> loopBounds(filterRank);
2997  for (unsigned dim = 0; dim < filterRank; ++dim) {
2998  loopBounds[dim].offset = zeroAttr;
2999  loopBounds[dim].size = getDimValue(builder, loc, filter, dim);
3000  loopBounds[dim].stride = oneAttr;
3001  }
3002  return loopBounds;
3003 }
3004 
3005 SmallVector<utils::IteratorType>
3006 WinogradFilterTransformOp::getLoopIteratorTypes() {
3007  int64_t filterRank = getFilterOperandRank();
3008  SmallVector<utils::IteratorType> iteratorTypes(filterRank,
3009  utils::IteratorType::parallel);
3010  return iteratorTypes;
3011 }
3012 
3013 LogicalResult WinogradFilterTransformOp::getResultTilePosition(
3014  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3015  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3016  SmallVector<OpFoldResult> &resultSizes) {
3017  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3018  ShapedType filterType = getFilterOperandType();
3019  ArrayRef<int64_t> filterShape = filterType.getShape();
3020  int64_t filterH = filterShape[getFilterHDim()];
3021  int64_t filterW = filterShape[getFilterWDim()];
3022  int64_t m = getM();
3023  int64_t r = getR();
3024  int64_t alpha = m + r - 1;
3025  int64_t alphaH = filterH != 1 ? alpha : 1;
3026  int64_t alphaW = filterW != 1 ? alpha : 1;
3027  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3028  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3029 
3030  resultOffsets.append(
3031  {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
3032  resultSizes.append(
3033  {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
3034 
3035  return success();
3036 }
3037 
3038 /// Implement tiling for winograd_filter_transform.
3039 /// The input of winograd_filter_transform is (F, KH, KW, C).
3040 /// The output of winograd_filter_transform is (alphaH, alphaW, C, F).
3041 /// Users can specify the tile sizes of F and C.
3042 /// `offsets` are the values for the offsets of F, KH, KW, C for one tile.
3043 /// `sizes` are the values for the sizes of F, KH, KW, C for one tile.
3044 FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
3045  OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3046  ArrayRef<OpFoldResult> sizes) {
3047  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3048  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3049  ShapedType filterType = getFilterOperandType();
3050  ArrayRef<int64_t> filterShape = filterType.getShape();
3051  int64_t filterH = filterShape[getFilterHDim()];
3052  int64_t filterW = filterShape[getFilterWDim()];
3053  IntegerAttr filterHAttr = builder.getI64IntegerAttr(filterH);
3054  IntegerAttr filterWAttr = builder.getI64IntegerAttr(filterW);
3055  SmallVector<Value> tiledOperands;
3056  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3057 
3058  sliceOffsets.append(
3059  {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
3060  sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
3061  sizes[getFilterCDim()]});
3062  int64_t filterRank = getFilterOperandRank();
3063  SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
3064  Location loc = getLoc();
3065  auto filterSlice = builder.create<tensor::ExtractSliceOp>(
3066  loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
3067  tiledOperands.emplace_back(filterSlice);
3068 
3069  SmallVector<OpFoldResult> resultOffsets, resultSizes;
3070  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3071  resultSizes)))
3072  return failure();
3073 
3074  int64_t outputRank = getOutputOperandRank();
3075  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3076  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
3077  loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3078  tiledOperands.emplace_back(outputSlice);
3079 
3080  SmallVector<Type> resultTypes;
3081  resultTypes.push_back(tiledOperands[1].getType());
3082  Operation *tiledOp =
3083  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3084 
3085  return TilingResult{
3086  {tiledOp},
3087  SmallVector<Value>(tiledOp->getResults()),
3088  llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
3089 }
3090 
3091 //===----------------------------------------------------------------------===//
3092 // WinogradInputTransformOp
3093 //===----------------------------------------------------------------------===//
3094 
3095 LogicalResult WinogradInputTransformOp::verify() {
3096  auto inputType = cast<ShapedType>(getInput().getType());
3097  ArrayRef<int64_t> inputShape = inputType.getShape();
3098  int64_t inputH = inputShape[getInputHDim()];
3099  int64_t inputW = inputShape[getInputWDim()];
3100  int m = getM();
3101  int r = getR();
3102  int64_t tileSize = m + r - 1;
3103 
3104  auto outputType = cast<ShapedType>(getOutput().getType());
3105  ArrayRef<int64_t> outputShape = outputType.getShape();
3106  bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
3107  bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
3108 
3109  SmallVector<int64_t> expectedOutputShape(6, inputH);
3110  if (ShapedType::isDynamic(inputH)) {
3111  expectedOutputShape[getOutputAlphaHDim()] = tileSize;
3112  expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
3113  } else {
3114  expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
3115  expectedOutputShape[getOutputTileHDim()] =
3116  leftTransform ? (inputH - (r - 1)) / m : inputH;
3117  }
3118  if (ShapedType::isDynamic(inputW)) {
3119  expectedOutputShape[getOutputAlphaWDim()] = tileSize;
3120  expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
3121  } else {
3122  expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
3123  expectedOutputShape[getOutputTileWDim()] =
3124  rightTransform ? (inputW - (r - 1)) / m : inputW;
3125  }
3126  expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
3127  expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
3128 
3129  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3130  return emitOpError("unexpected output shape");
3131  }
3132  return success();
3133 }
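
// Tile-count example (hypothetical sizes): with m = 4, r = 3 (tileSize = 6)
// and a static input of shape (N, H=14, W=14, C) where both transforms apply,
// the expected output shape is (6, 6, tileH, tileW, N, C) with
// tileH = (14 - (3 - 1)) / 4 = 3 and likewise tileW = 3.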
3134 
3135 SmallVector<Range>
3136 WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
3137  Location loc = getLoc();
3138  IntegerAttr zeroAttr = builder.getIndexAttr(0);
3139  IntegerAttr oneAttr = builder.getIndexAttr(1);
3140  Value output = getOutput();
3141  int64_t outputRank = getOutputOperandRank();
3142  SmallVector<Range> loopBounds(outputRank);
3143  for (unsigned dim = 0; dim < outputRank; ++dim) {
3144  loopBounds[dim].offset = zeroAttr;
3145  // alphaH, alphaW, tileH, tileW, N, C
3146  loopBounds[dim].size = getDimValue(builder, loc, output, dim);
3147  loopBounds[dim].stride = oneAttr;
3148  }
3149  return loopBounds;
3150 }
3151 
3152 SmallVector<utils::IteratorType>
3153 WinogradInputTransformOp::getLoopIteratorTypes() {
3154  int64_t outputRank = getOutputOperandRank();
3155  SmallVector<utils::IteratorType> iteratorTypes(outputRank,
3156  utils::IteratorType::parallel);
3157  return iteratorTypes;
3158 }
3159 
3160 LogicalResult WinogradInputTransformOp::getResultTilePosition(
3161  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3162  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3163  SmallVector<OpFoldResult> &resultSizes) {
3164  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3165  ShapedType outputType = getOutputOperandType();
3166  ArrayRef<int64_t> outputShape = outputType.getShape();
3167  int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
3168  int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
3169 
3170  int64_t m = getM();
3171  int64_t r = getR();
3172  int64_t alpha = m + r - 1;
3173  int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
3174  int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
3175 
3176  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3177  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3178 
3179  resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
3180  offsets[getOutputTileWDim()], offsets[getOutputNDim()],
3181  offsets[getOutputCDim()]});
3182  resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
3183  sizes[getOutputTileWDim()], sizes[getOutputNDim()],
3184  sizes[getOutputCDim()]});
3185 
3186  return success();
3187 }
3188 
3189 /// Implement tiling for winograd_input_transform.
3190 /// The input of winograd_input_transform is (N, H, W, C).
3191 /// The output of winograd_input_transform is (alphaH, alphaW, tileH, tileW,
3192 /// N, C). Users can specify the tile sizes of tileH, tileW, N, and C.
3193 /// `offsets` are the values for the offsets of tileH, tileW, N, C for one
3194 /// tile. `sizes` are the values for the sizes of tileH, tileW, N, C for one tile.
3195 FailureOr<TilingResult>
3196 WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
3197  ArrayRef<OpFoldResult> offsets,
3198  ArrayRef<OpFoldResult> sizes) {
3199  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3200  int64_t m = getM();
3201  int64_t r = getR();
3202 
3203  ShapedType outputType = getOutputOperandType();
3204  ArrayRef<int64_t> outputShape = outputType.getShape();
3205  int64_t alphaH = outputShape[getOutputAlphaHDim()];
3206  int64_t alphaW = outputShape[getOutputAlphaWDim()];
3207 
3208  Location loc = getLoc();
3209  MLIRContext *context = builder.getContext();
3210  auto identityAffineMap =
3211  AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
3212  auto offsetAffineMap =
3213  AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3214  Value mappedOffsetH = affine::makeComposedAffineApply(
3215  builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
3216  offsets[getOutputTileHDim()]);
3217  Value mappedOffsetW = affine::makeComposedAffineApply(
3218  builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
3219  offsets[getOutputTileWDim()]);
3220  auto sizeAffineMap = AffineMap::get(
3221  1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);
3222  Value mappedSizeH = affine::makeComposedAffineApply(
3223  builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
3224  Value mappedSizeW = affine::makeComposedAffineApply(
3225  builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
3226 
3227  SmallVector<Value> tiledOperands;
3228  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3229 
3230  OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3231  OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3232  sliceOffsets.append(
3233  {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
3234  OpFoldResult sizeH =
3235  alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3236  OpFoldResult sizeW =
3237  alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3238  sliceSizes.append(
3239  {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
3240  int64_t inputRank = getInputOperandRank();
3241  SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
3242  auto inputSlice = builder.create<tensor::ExtractSliceOp>(
3243  loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
3244  tiledOperands.emplace_back(inputSlice);
3245 
3246  SmallVector<OpFoldResult> resultOffsets, resultSizes;
3247  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3248  resultSizes)))
3249  return failure();
3250 
3251  int64_t outputRank = getOutputOperandRank();
3252  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3253  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
3254  loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3255  tiledOperands.emplace_back(outputSlice);
3256 
3257  SmallVector<Type> resultTypes;
3258  resultTypes.push_back(tiledOperands[1].getType());
3259  Operation *tiledOp =
3260  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3261 
3262  return TilingResult{
3263  {tiledOp},
3264  SmallVector<Value>(tiledOp->getResults()),
3265  llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
3266 }
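
// Mapping sketch (illustrative numbers): with m = 4 and r = 3, the affine
// maps above scale a tile offset of 2 to the input offset 2 * 4 = 8, and a
// tile size of 3 to the input slice extent 3 * 4 + (3 - 1) = 14; the extra
// r - 1 rows/columns account for the overlap between adjacent Winograd tiles.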
3267 
3268 //===----------------------------------------------------------------------===//
3269 // WinogradOutputTransformOp
3270 //===----------------------------------------------------------------------===//
3271 
3272 LogicalResult WinogradOutputTransformOp::verify() {
3273  auto valueType = cast<ShapedType>(getValue().getType());
3274  ArrayRef<int64_t> valueShape = valueType.getShape();
3275  int64_t valueH = valueShape[getValueAlphaHDim()];
3276  int64_t valueW = valueShape[getValueAlphaWDim()];
3277  int64_t valueTileH = valueShape[getValueTileHDim()];
3278  int64_t valueTileW = valueShape[getValueTileWDim()];
3279  int m = getM();
3280  int r = getR();
3281  bool leftTransform = valueH != 1;
3282  bool rightTransform = valueW != 1;
3283 
3284  int64_t outputRank = getOutputOperandRank();
3285  SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
3286  if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
3287  expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
3288  } else {
3289  if (valueH != (leftTransform ? m + r - 1 : 1))
3290  return emitOpError("expect input height to equal the input tile size");
3291  expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
3292  }
3293  if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
3294  expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
3295  } else {
3296  if (valueW != (rightTransform ? m + r - 1 : 1))
3297  return emitOpError("expect input width to equal the input tile size");
3298  expectedOutputShape[getOutputWDim()] =
3299  (rightTransform ? m : 1) * valueTileW;
3300  }
3301  expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
3302  expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];
3303 
3304  auto outputType = cast<ShapedType>(getOutput().getType());
3305  ArrayRef<int64_t> outputShape = outputType.getShape();
3306  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3307  return emitOpError("unexpected output shape");
3308  }
3309  return success();
3310 }
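
// Shape example (hypothetical sizes): with m = 4 and r = 3, a value tensor of
// shape (6, 6, tileH=3, tileW=3, N, F) verifies and implies an output shape
// of (N, 4 * 3 = 12, 12, F); with a height-only transform (valueW == 1) the
// output width would instead stay at valueTileW.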
3311 
3312 SmallVector<Range>
3313 WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
3314  Location loc = getLoc();
3315  IntegerAttr zeroAttr = builder.getIndexAttr(0);
3316  IntegerAttr oneAttr = builder.getIndexAttr(1);
3317  Value value = getValue();
3318  int64_t valueRank = getValueOperandRank();
3319  SmallVector<Range> loopBounds(valueRank);
3320  for (unsigned dim = 0; dim < valueRank; ++dim) {
3321  loopBounds[dim].offset = zeroAttr;
3322  // alphaH, alphaW, tileH, tileW, N, F
3323  loopBounds[dim].size = getDimValue(builder, loc, value, dim);
3324  loopBounds[dim].stride = oneAttr;
3325  }
3326  return loopBounds;
3327 }
3328 
3329 SmallVector<utils::IteratorType>
3330 WinogradOutputTransformOp::getLoopIteratorTypes() {
3331  int64_t valueRank = getValueOperandRank();
3332  SmallVector<utils::IteratorType> iteratorTypes(valueRank,
3333  utils::IteratorType::parallel);
3334  return iteratorTypes;
3335 }
3336 
3337 LogicalResult WinogradOutputTransformOp::getResultTilePosition(
3338  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3339  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3340  SmallVector<OpFoldResult> &resultSizes) {
3341  int64_t m = getM();
3342 
3343  Location loc = getLoc();
3344  MLIRContext *context = builder.getContext();
3345  auto identityAffineMap =
3346  AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
3347  auto affineMap =
3348  AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3349 
3350  ShapedType valueType = getValueOperandType();
3351  ArrayRef<int64_t> valueShape = valueType.getShape();
3352  int64_t valueH = valueShape[0];
3353  int64_t valueW = valueShape[1];
3354  Value mappedOffsetH = affine::makeComposedAffineApply(
3355  builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
3356  offsets[getValueTileHDim()]);
3357  Value mappedOffsetW = affine::makeComposedAffineApply(
3358  builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
3359  offsets[getValueTileWDim()]);
3360  Value mappedSizeH = affine::makeComposedAffineApply(
3361  builder, loc, affineMap, sizes[getValueTileHDim()]);
3362  Value mappedSizeW = affine::makeComposedAffineApply(
3363  builder, loc, affineMap, sizes[getValueTileWDim()]);
3364 
3365  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3366  OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3367  OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3368  OpFoldResult sizeH =
3369  valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3370  OpFoldResult sizeW =
3371  valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3372 
3373  resultOffsets.append(
3374  {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
3375  resultSizes.append(
3376  {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
3377  return success();
3378 }
3379 
3380 /// Implement tiling for winograd_output_transform.
3381 /// The input of winograd_output_transform is (alphaH, alphaW, tileH, tileW,
3382 /// N, F). The output of winograd_output_transform is (N, H, W, F).
3383 /// Users can specify the tile sizes of tileH, tileW, N, and F. `offsets` are
3384 /// the values for the offsets of tileH, tileW, N, F for one tile. `sizes`
3385 /// are the values for the sizes of tileH, tileW, N, F for one tile.
3386 FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
3387  OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3388  ArrayRef<OpFoldResult> sizes) {
3389  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3390  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3391  Location loc = getLoc();
3392  SmallVector<Value> tiledOperands;
3393  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3394 
3395  ShapedType valueType = getValueOperandType();
3396  ArrayRef<int64_t> valueShape = valueType.getShape();
3397  int64_t alphaH = valueShape[getValueAlphaHDim()];
3398  int64_t alphaW = valueShape[getValueAlphaWDim()];
3399  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3400  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3401 
3402  sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
3403  offsets[getValueTileWDim()], offsets[getValueNDim()],
3404  offsets[getValueFDim()]});
3405  sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
3406  sizes[getValueTileWDim()], sizes[getValueNDim()],
3407  sizes[getValueFDim()]});
3408  int64_t valueRank = getValueOperandRank();
3409  SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
3410  auto valueSlice = builder.create<tensor::ExtractSliceOp>(
3411  loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
3412  tiledOperands.emplace_back(valueSlice);
3413 
3414  SmallVector<OpFoldResult> resultOffsets, resultSizes;
3415  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3416  resultSizes)))
3417  return failure();
3418 
3419  int64_t outputRank = getOutputOperandRank();
3420  SmallVector<OpFoldResult> strides(outputRank, oneAttr);
3421  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
3422  loc, getOutput(), resultOffsets, resultSizes, strides);
3423  tiledOperands.emplace_back(outputSlice);
3424 
3425  SmallVector<Type> resultTypes;
3426  resultTypes.push_back(tiledOperands[1].getType());
3427  Operation *tiledOp =
3428  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3429 
3430  return TilingResult{
3431  {tiledOp},
3432  SmallVector<Value>(tiledOp->getResults()),
3433  llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
3434 }
3435 
3436 //===----------------------------------------------------------------------===//
3437 // LinalgDialect
3438 // TODO: Merge with the LinalgDialect block at the bottom
3439 //===----------------------------------------------------------------------===//
3440 
3441 // Returns true if the result expressions of `subMap` are a subset of those of `fullMap`.
3442 static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap) {
3443  auto explicitRange = subMap.getResults();
3444  auto defaultRange = fullMap.getResults();
3445  DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
3446  DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
3447  llvm::set_union(explicitSet, defaultSet);
3448  return explicitSet == defaultSet;
3449 }
3450 
3451 /// Check if the user-defined map is a valid broadcast map. Here broadcast
3452 /// indexing maps are defined in the context of the corresponding default
3453 /// indexing maps for the given op. This makes the check very simple: just
3454 /// check the number of result dims.
3455 /// Returns true if `explicitMap` is broadcast with respect to the
3456 /// `defaultMap`.
3457 static bool isBroadcasted(AffineMap explicitMap, AffineMap defaultMap) {
3458  return explicitMap.getNumResults() < defaultMap.getNumResults();
3459 }
3460 
3461 /// Verifies the broadcast and transpose semantics specified by the explicit
3462 /// indexing map for the MatmulOp \p op for each operand specified by \p
3463 /// opIndex.
3464 static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
3465  unsigned opIndex) {
3466  SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
3467  SmallVector<AffineMap, 3> defaultIndexingMaps =
3468  matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
3469 
3470  auto opIndexingMap = opIndexingMaps[opIndex];
3471  auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3472  // Check general validity of indexing map results.
3473  if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
3474  return matmulOp->emitOpError()
3475  << "Unexpected dim expression in map result.";
3476 
3477  if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3478  if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
3479  return matmulOp->emitOpError()
3480  << "Invalid broadcast requested, should be (d2).";
3481  }
3482  return success();
3483  }
3484  return success();
3485 }
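
// Illustrative broadcast map (mirrors the diagnostic above, not a normative
// spec): the default LHS map of matmul is (d0, d1, d2) -> (d0, d2); the only
// accepted broadcast keeps just the shared reduction dim, i.e.
//   affine_map<(d0, d1, d2) -> (d2)>
// whereas, e.g., (d0, d1, d2) -> (d0) would be rejected.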
3486 
3487 // Check general validity of input indexing map.
3488 static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp,
3489  AffineMap opIndexingMap,
3490  AffineMap defaultIndexingMap, bool isLHS) {
3491  // Check the result dims are valid.
3492  if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
3493  return batchMatmulOp->emitOpError()
3494  << "Unexpected result dim expression (outside the set of default "
3495  "result dims).";
3496 
3497  // Check for valid number of result dims of input maps.
3498  if (opIndexingMap.getNumResults() > 3)
3499  return batchMatmulOp->emitOpError()
3500  << "no. of result dim expressions exceeds 3.";
3501 
3502  auto hasValidBatchDim = [](AffineMap map) {
3503  AffineExpr batchDim = map.getResult(0);
3504  return batchDim.isFunctionOfDim(0);
3505  };
3506 
3507  // Check if the requested broadcast is valid.
3508  if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3509  if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
3510  return batchMatmulOp->emitOpError() << "Invalid broadcast requested.";
3511  } else if (!hasValidBatchDim(opIndexingMap)) {
3512  return batchMatmulOp->emitOpError()
3513  << "Invalid batch dimension expression.";
3514  }
3515  return success();
3516 }
3517 
3518 /// This function checks if the given AffineMap for the output of a
3519 /// BatchMatmulOp has exactly 3 result dimensions and if the output map result
3520 /// dimensions are valid.
3521 static LogicalResult verifyOutputMap(BatchMatmulOp batchMatmulOp,
3522  AffineMap opIndexingMap) {
3523  if (opIndexingMap.getNumResults() != 3)
3524  return batchMatmulOp->emitOpError()
3525  << "expects 3 dims, but got (" << opIndexingMap.getNumResults()
3526  << ").";
3527 
3528  auto areValidOutputResultDim = [](AffineMap outputMap) {
3529  return outputMap.getResult(0).isFunctionOfDim(0) &&
3530  outputMap.getResult(1).isFunctionOfDim(1) &&
3531  outputMap.getResult(2).isFunctionOfDim(2);
3532  };
3533 
3534  if (!areValidOutputResultDim(opIndexingMap))
3535  return batchMatmulOp->emitOpError()
3536  << "Invalid output map result dimension.";
3537 
3538  return success();
3539 }
3540 
3541 /// Verifies the broadcast and transpose semantic specified by the explicit
3542 /// indexing map for the BatchMatmulOp op for each operand specified by opIndex.
3543 static LogicalResult
3544 verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp,
3545  unsigned opIndex) {
3546  SmallVector<AffineMap, 3> opIndexingMaps =
3547  batchMatmulOp.getIndexingMapsArray();
3548  SmallVector<AffineMap, 3> defaultIndexingMaps =
3549  batchMatmulOp.getDefaultIndexingMaps(batchMatmulOp->getContext());
3550 
3551  if (opIndexingMaps.size() != 3)
3552  return batchMatmulOp->emitOpError()
3553  << "Indexing_map attribute must have 3 affine maps.";
3554 
3555  auto opIndexingMap = opIndexingMaps[opIndex];
3556  auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3557 
3558  if (opIndex == 2 && failed(verifyOutputMap(batchMatmulOp, opIndexingMap)))
3559  return failure();
3560 
3561  if (failed(verifyInputMaps(batchMatmulOp, opIndexingMap, defaultIndexingMap,
3562  opIndex == 0)))
3563  return failure();
3564 
3565  return success();
3566 }
3567 
3568 namespace mlir {
3569 namespace linalg {
3570 
3571 //===----------------------------------------------------------------------===//
3572 // MatMulOp
3573 //===----------------------------------------------------------------------===//
3574 
3575 /// Returns a list of AffineMaps with the typical matmul indexing characteristics.
3576 SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
3577  AffineExpr d0, d1, d2;
3578  SmallVector<AffineMap> indexingMaps;
3579  bindDims(context, d0, d1, d2);
3580  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
3581  indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
3582  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
3583  return indexingMaps;
3584 }
3585 
3586 SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
3587  return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
3588  utils::IteratorType::parallel,
3589  utils::IteratorType::reduction};
3590 }
3591 
3592 unsigned MatmulOp::getNumRegionArgs() { return 3; }
3593 
3594 std::string MatmulOp::getLibraryCallName() {
3595  return generateLibraryCallName(getOperation());
3596 }
3597 
3598 bool MatmulOp::hasDynamicIndexingMaps() { return true; }
3599 
3600 /// Check if the op has broadcast and/or transpose semantics. Returns true if
3601 /// the user-defined indexing maps are not equal to the default maps.
3602 bool MatmulOp::hasUserDefinedMaps() {
3603  SmallVector<AffineMap, 3> defaultMaps =
3604  getDefaultIndexingMaps(this->getContext());
3605  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
3606  return defaultMaps != explicitMaps;
3607 }
3608 
3609 /// Implements the block region builder for the MatmulOp. This is called by
3610 /// 'fillStructuredOpRegion'.
3611 void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
3612  ArrayRef<NamedAttribute> attrs) {
3613  assert(block.getNumArguments() == 3 &&
3614  "MatmulOp regionBuilder expects 3 args");
3615  RegionBuilderHelper helper(b, block);
3616  SmallVector<Value> yields;
3617 
3618  TypeFn castVal = TypeFn::cast_signed;
3619  const auto *castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
3620  return attr.getName() == "cast";
3621  });
3622  if (castIter != attrs.end()) {
3623  if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3624  castVal = attr.getValue();
3625  }
3626 
3627  Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
3628  block.getArgument(0));
3629  Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
3630  block.getArgument(1));
3631  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
3632  Value value4 =
3633  helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
3634  yields.push_back(value4);
3635  helper.yieldOutputs(yields);
3636 }
3637 
3638 /// Returns true if the given broadcast map \p bcastMap is valid for this op.
3639 bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
3640  assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
3641  AffineExpr exp = bcastMap.getResult(0);
3642  // The map is invalid if the common (reduction) dimension of the matmul is not found.
3643  return exp.isFunctionOfDim(bcastMap.getNumDims() - 1);
3644 }
3645 
3646 FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
3647  if (parser.parseOptionalKeyword("indexing_maps"))
3648  return ArrayAttr{
3649  nullptr}; // Success in case indexing_maps was not provided.
3650 
3651  ArrayAttr arrayAttr;
3652  if (parser.parseEqual() || parser.parseAttribute(arrayAttr))
3653  return failure();
3654 
3655  if (llvm::any_of(arrayAttr,
3656  [](auto elt) { return !dyn_cast<AffineMapAttr>(elt); }))
3657  return parser.emitError(parser.getCurrentLocation())
3658  << "element of indexing_maps array is not an affine_map";
3659 
3660  return arrayAttr;
3661 }
3662 
3663 ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
3664  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
3665  if (failed(indexingMapsAttr))
3666  return failure();
3667 
3668  if (*indexingMapsAttr == nullptr) {
3669  auto indexingMapAttrs = llvm::map_to_vector(
3670  MatmulOp::getDefaultIndexingMaps(parser.getContext()),
3671  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
3672  indexingMapsAttr = parser.getBuilder().getArrayAttr(indexingMapAttrs);
3673  }
3674 
3675  result.addAttribute("indexing_maps", *indexingMapsAttr);
3676  return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
3677  MatmulOp::getRegionBuilder());
3678 }
3679 
3680 void MatmulOp::print(OpAsmPrinter &p) {
3681  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
3682  MatmulOp::getDefaultIndexingMaps(getContext()),
3683  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
3684  if (!llvm::equal(getIndexingMaps(), indexingMaps))
3685  p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
3686 
3687  std::array<StringRef, 3> elidedAttrs = {
3688  "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
3689  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
3690  elidedAttrs);
3691 }
3692 
3693 /// Verify the user defined indexing maps.
3694 LogicalResult MatmulOp::verify() {
3695  // Verification of pure matmul is handled by verifyStructuredOpInterface().
3696  if (!hasUserDefinedMaps())
3697  return success();
3698 
3699  for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
3700  if (failed(verifyExtendedMatmulSemantic(*this, opIndex)))
3701  return failure();
3702  }
3703  return success();
3704 }
3705 
3706 LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
3707  return memref::foldMemRefCast(*this);
3708 }
3709 
3710 void MatmulOp::getEffects(
3711  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3712  &effects) {
3713  if (hasPureTensorSemantics())
3714  return;
3715  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
3716 }
3717 
3718 Speculation::Speculatability MatmulOp::getSpeculatability() {
3719  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
3720 }
3721 
3722 //===----------------------------------------------------------------------===//
3723 // ContractOp
3724 //===----------------------------------------------------------------------===//
3725 
3726 SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
3727  AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
3728  // On well-formed IR, indexing_maps is non-empty, contained affine_maps'
3729  // domains are all the same, and each implements a projected permutation.
3730  // Each iteration space dim must occur for at least one operand and either
3731  // takes part in a contraction/reduction or else has parallel iteration type.
3732  // We have that a dim is a contraction/reduction dim if and only if the dim
3733  // occurs for the output operand. We use this fact for fast inference:
3734  // NB: In case we allow dims to occur solely for one input, the above still
3735  // holds: per the einsum semantics, these are reduction dims as well.
3736  SmallVector<bool> dimsInOutput(outAffineMap.getNumDims(), false);
3737  for (auto result : outAffineMap.getResults()) {
3738  auto dimExpr = dyn_cast<AffineDimExpr>(result);
3739  assert(dimExpr && "affine_map is a projected permutation");
3740  dimsInOutput[dimExpr.getPosition()] = true;
3741  }
3742 
3743  SmallVector<utils::IteratorType> iteratorTypes;
3744  for (auto dimOccursInOutput : dimsInOutput)
3745  iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
3746  : utils::IteratorType::reduction);
3747 
3748  return iteratorTypes;
3749 }
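
// Inference example (illustrative): for a batched matvec with maps
//   A: (b, m, k) -> (b, m, k),  x: (b, m, k) -> (b, k),
//   out: (b, m, k) -> (b, m)
// only b and m occur in the output map, so the inferred iterator types are
// [parallel, parallel, reduction].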
3750 
3751 unsigned ContractOp::getNumRegionArgs() { return 3; }
3752 
3753 /// Implement block region builder, which is called by 'fillStructuredOpRegion'.
3754 void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
3755  ArrayRef<NamedAttribute> attrs) {
3756  assert(block.getNumArguments() == 3 &&
3757  "ContractOp regionBuilder expects 3 args");
3758  RegionBuilderHelper helper(b, block);
3759 
3760  TypeFn castSignedness = TypeFn::cast_signed;
3761  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
3762  return attr.getName() == "cast";
3763  });
3764  if (castIter != attrs.end()) {
3765  if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3766  castSignedness = attr.getValue();
3767  }
3768 
3769  // TODO: Support fields with operators besides mult & add.
3770  Type outType = block.getArgument(2).getType();
3771  Value lhsAtOutType =
3772  helper.buildTypeFn(castSignedness, outType, block.getArgument(0));
3773  Value rhsAtOutType =
3774  helper.buildTypeFn(castSignedness, outType, block.getArgument(1));
3775  Value productAtOutType =
3776  helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType, rhsAtOutType);
3777  Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
3778  productAtOutType);
3779  helper.yieldOutputs({result});
3780 }
3781 
3782 ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
3783  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
3784  if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
3785  return parser.emitError(parser.getCurrentLocation(),
3786  "expected 'indexing_maps' attribute");
3787  result.addAttribute("indexing_maps", *indexingMapsAttr);
3788 
3789  return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
3790  regionBuilder);
3791 }
3792 
3793 void ContractOp::print(OpAsmPrinter &p) {
3794  p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
3795  printNamedStructuredOp(
3796  p, getOperation(), getInputs(), getOutputs(),
3797  /*elidedAttrs=*/{"indexing_maps", "operandSegmentSizes"});
3798 }
3799 
3800 LogicalResult ContractOp::verify() {
3801  int iterationSpaceDims = -1;
3802  // Map iter space dims to #occurrences in inputs' and output's affine_maps:
3803  // e.g., inOccurrences[0] will hold #times that dim (with index) 0 is used to
3804  // access an input operand (so occurrence count can be at most 2) and
3805  // outOccurrences[1] will indicate whether dim 1 occurred in the output, etc.
3806  SmallVector<size_t> inOccurrences;
3807  SmallVector<size_t> outOccurrences;
3808 
3809  // A helper so that for each operand's affine_map and type we check that ...
3810  auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
3811  bool isInput) -> LogicalResult {
3812  // ... the affine_map is a projected permutation;
3813  if (!affineMap.isProjectedPermutation())
3814  return emitError("provided affine_map is not a projected permutation");
3815 
3816  // ... the rank of the affine_map's results and corresponding type match;
3817  if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
3818  if (affineMap.getNumResults() != shapedType.getRank())
3819  return emitError("ranks of shaped operand and results of corresponding "
3820  "affine_map differ");
3821  } else if (affineMap.getNumResults() != 0) {
3822  return emitError("affine_map specifies shaped access while operand has "
3823  "non-shaped type");
3824  }
3825 
3826  // ... the rank of the affine_map's domain is the same as those seen prior;
3827  if (iterationSpaceDims == -1) {
3828  iterationSpaceDims = affineMap.getNumDims();
3829  inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
3830  outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
3831  } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
3832  return emitError("iteration spaces of provided affine_maps differ");
3833  }
3834 
3835  // ... update counts of dims used to access either an input or the output.
3836  for (AffineExpr affineExpr : affineMap.getResults()) {
3837  auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
3838  if (!affineDimExpr)
3839  llvm_unreachable("affine_map is a projected permutation");
3840 
3841  if (isInput)
3842  inOccurrences[affineDimExpr.getPosition()] += 1;
3843  else
3844  outOccurrences[affineDimExpr.getPosition()] += 1;
3845  }
3846 
3847  return success();
3848  };
3849 
3850  for (auto &&[affineMap, operandType, isInput] :
3851  llvm::zip(getIndexingMapsArray(), getOperandTypes(),
3852  SmallVector<bool>{true, true, false})) {
3853  if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
3854  return failure(); // NB: checkAffineMapAndType will emit relevant error.
3855  }
3856 
3857  bool hasContractingDim = false;
3858  for (size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
3859  size_t inOccCount = inOccurrences[dimIndex];
3860  size_t outOccCount = outOccurrences[dimIndex];
3861 
3862  // We have a contracting dim if and only if ...
3863  hasContractingDim |= inOccCount == 2 && outOccCount == 0;
3864 
3865  if (inOccCount == 0 && outOccCount == 0)
3866  return emitError() << "iteration space dim at index " << dimIndex
3867  << " not used to access any operand";
3868 
3869  // NB: We disallow a dim which occurs for only one input operand and not
3870  // for the output. In terms of einsum semantics such dims have a
3871  // sensible meaning - namely an additional reduction per each such dim.
3872  // By contrast, the ContractionOpInterface does not know about this
3873  // iter type - cf. inferContractionDims' supported dim kinds. Similarly,
3874  // while vector.contract's verifier accepts dims of this kind many of
3875  // its lowerings give up on encountering these dims.
3876  // TODO: Remove following once we have comprehensive support for input-only
3877  // reduction dims, at both the linalg- and vector-dialect levels.
3878  if (inOccCount == 1 && outOccCount != 1)
3879  return emitError()
3880  << "iteration space dim at index " << dimIndex
3881  << " is neither a contracting dim nor of parallel iteration type";
3882  }
3883 
3884  if (!hasContractingDim)
3885  return emitError("'indexing_maps' do not specify a contracting dimension");
3886 
3887  return success();
3888 }
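
// Occurrence-count example (illustrative): for the plain matmul maps
//   (m, n, k) -> (m, k), (m, n, k) -> (k, n), (m, n, k) -> (m, n)
// we get inOccurrences = [1, 1, 2] and outOccurrences = [1, 1, 0]; k has
// inOccCount == 2 and outOccCount == 0, so it is the contracting dim and
// verification succeeds.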
3889 
3890 LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
3891  return memref::foldMemRefCast(*this);
3892 }
3893 
3894 void ContractOp::getEffects(
3895  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3896  &effects) {
3897  if (hasPureTensorSemantics())
3898  return;
3899  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
3900 }
3901 
3902 Speculation::Speculatability ContractOp::getSpeculatability() {
3903  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
3904 }
3905 
3906 //===----------------------------------------------------------------------===//
3907 // Implementation of BatchMatmulOp
3908 //===----------------------------------------------------------------------===//
3909 SmallVector<AffineMap>
3910 BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
3911  AffineExpr d0, d1, d2, d3;
3912  SmallVector<AffineMap> indexingMaps;
3913  bindDims(context, d0, d1, d2, d3);
3914  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
3915  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
3916  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d2}, context));
3917  return indexingMaps;
3918 }
3919 
3920 SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
3921  return SmallVector<utils::IteratorType>{
3922  utils::IteratorType::parallel, utils::IteratorType::parallel,
3923  utils::IteratorType::parallel, utils::IteratorType::reduction};
3924 }
3925 
3926 unsigned BatchMatmulOp::getNumRegionArgs() { return 3; }
3927 
3928 std::string BatchMatmulOp::getLibraryCallName() {
3929  return generateLibraryCallName(getOperation());
3930 }
3931 
3932 /// Check if the op has broadcast and/or transpose semantics. Returns true if
3933 /// the user-defined indexing maps are not equal to the default maps.
3934 bool BatchMatmulOp::hasUserDefinedMaps() {
3935  SmallVector<AffineMap, 3> defaultMaps =
3936  getDefaultIndexingMaps(this->getContext());
3937  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
3938  return defaultMaps != explicitMaps;
3939 }
3940 
3941 /// Returns true if the given broadcast map bcastMap is valid for this op.
3942 bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
3943  assert(bcastMap.getNumResults() < 3 &&
3944  "Expected less than 3 result dim expr.");
3945  bool isValid = false;
3946  enum Indices { batchPos, mPos, nPos, kPos };
3947  if (bcastMap.getNumResults() == 1) {
3948  AffineExpr exp = bcastMap.getResult(0);
3949  isValid = exp.isFunctionOfDim(kPos);
3950  } else if (bcastMap.getNumResults() == 2) {
3951  AffineExpr exp0 = bcastMap.getResult(0);
3952  AffineExpr exp1 = bcastMap.getResult(1);
3953  isValid = isLHS
3954  ? (exp0.isFunctionOfDim(mPos) && exp1.isFunctionOfDim(kPos))
3955  : (exp0.isFunctionOfDim(kPos) && exp1.isFunctionOfDim(nPos));
3956  }
3957  return isValid;
3958 }
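
// Illustrative maps (hypothetical, following the Indices enum above): with
// (d0=batch, d1=m, d2=n, d3=k), the rank-1 broadcast
//   affine_map<(d0, d1, d2, d3) -> (d3)>
// is accepted for either side, while a rank-2 LHS map must be (d1, d3) and a
// rank-2 RHS map must be (d3, d2).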
3959 
3960 void BatchMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
3961  ArrayRef<NamedAttribute> attrs) {
3962  assert(block.getNumArguments() == 3 &&
3963  "BatchMatmulOp regionBuilder expects 3 args");
3964  RegionBuilderHelper helper(b, block);
3965  SmallVector<Value> yields;
3966 
3967  TypeFn castVal = TypeFn::cast_signed;
3968  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
3969  return attr.getName() == "cast";
3970  });
3971  if (castIter != attrs.end()) {
3972  if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3973  castVal = attr.getValue();
3974  }
3975 
3976  auto toType = block.getArgument(2).getType();
3977  Value castValA = helper.buildTypeFn(castVal, toType, block.getArgument(0));
3978  Value castValB = helper.buildTypeFn(castVal, toType, block.getArgument(1));
3979  Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
3980  Value addVal =
3981  helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
3982  yields.push_back(addVal);
3983  helper.yieldOutputs(yields);
3984 }
3985 
3986 ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
3987  SmallVector<Attribute, 3> indexingMapsAttr;
3988  Attribute mapAttr;
3989  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
3990  if (parser.parseEqual())
3991  return failure();
3992 
3993  if (parser.parseLSquare())
3994  return failure();
3995 
3996  do {
3997  if (parser.parseAttribute(mapAttr))
3998  return failure();
3999  if (!isa<AffineMapAttr>(mapAttr)) {
4000  return parser.emitError(parser.getCurrentLocation(),
4001  "expected affine map attribute");
4002  }
4003  indexingMapsAttr.push_back(mapAttr);
4004 
4005  if (parser.parseOptionalComma())
4006  break;
4007  } while (true);
4008 
4009  if (parser.parseRSquare())
4010  return failure();
4011  }
4012  // Initialize indexingMaps, if not supplied explicitly.
4013  if (indexingMapsAttr.empty()) {
4014  indexingMapsAttr = llvm::map_to_vector(
4015  BatchMatmulOp::getDefaultIndexingMaps(parser.getContext()),
4016  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4017  }
4018  result.addAttribute("indexing_maps",
4019  parser.getBuilder().getArrayAttr(indexingMapsAttr));
4020 
4021  return ::parseNamedStructuredOp(parser, result,
4022  BatchMatmulOp::getNumRegionArgs(),
4023  BatchMatmulOp::getRegionBuilder());
4024 }
4025 
4026 void BatchMatmulOp::print(OpAsmPrinter &p) {
4027  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4028  BatchMatmulOp::getDefaultIndexingMaps(getContext()),
4029  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4030  if (!llvm::equal(getIndexingMaps(), indexingMaps))
4031  p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4032 
4033  std::array<StringRef, 3> elidedAttrs = {
4034  "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
4035  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
4036  elidedAttrs);
4037 }
4038 
4039 /// Verify the user defined indexing maps.
4040 LogicalResult BatchMatmulOp::verify() {
4041  // Verification of pure batch_matmul is handled by
4042  // verifyStructuredOpInterface().
4043  if (!hasUserDefinedMaps())
4044  return success();
4045 
4046  for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
4047  if (failed(verifyExtendedBatchMatmulSemantic(*this, opIndex)))
4048  return failure();
4049  }
4050  return success();
4051 }
4052 
4053 LogicalResult BatchMatmulOp::fold(FoldAdaptor,
4054  SmallVectorImpl<OpFoldResult> &) {
4055  return memref::foldMemRefCast(*this);
4056 }
4057 
4058 void BatchMatmulOp::getEffects(
4059  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4060  &effects) {
4061  if (hasPureTensorSemantics())
4062  return;
4063  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4064 }
4065 
4066 Speculation::Speculatability BatchMatmulOp::getSpeculatability() {
4067  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
4068 }
4069 
4070 //===----------------------------------------------------------------------===//
4071 // ElementwiseOp
4072 //===----------------------------------------------------------------------===//
4073 //
4074 namespace {
4075 struct ArityGroupAndKind {
4076  // The enum class {Unary, Binary, Ternary, ..}
4077  ElementwiseArityGroup arityGroup;
4078 
4079  // The kind (e.g. `exp` or `add`) belonging to the arity group.
4080  union Kind {
4081  UnaryFn unaryFn;
4082  BinaryFn binaryFn;
4083  TernaryFn ternaryFn;
4084  } kind;
4085 };
4086 
4087 unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {
4088  return static_cast<unsigned>(arityGroup);
4089 }
4090 } // namespace
4091 
4092 static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind) {
4093  constexpr int lastUnary = static_cast<int>(ElementwiseCaseLimits::LastUnary);
4094  constexpr int lastBinary =
4095  static_cast<int>(ElementwiseCaseLimits::LastBinary);
4096  constexpr int lastTernary =
4097  static_cast<int>(ElementwiseCaseLimits::LastTernary);
4098 
4099  int val = static_cast<int>(kind);
4100  ArityGroupAndKind result;
4101 
4102  if (val < lastUnary) {
4103  result.arityGroup = ElementwiseArityGroup::Unary;
4104  result.kind.unaryFn = static_cast<UnaryFn>(val);
4105  return result;
4106  }
4107  if (val < lastBinary) {
4108  result.arityGroup = ElementwiseArityGroup::Binary;
4109  result.kind.binaryFn = static_cast<BinaryFn>(val - lastUnary);
4110  return result;
4111  }
4112  if (val >= lastTernary) {
4113  llvm_unreachable("unhandled ElementwiseFn");
4114  }
4115  result.arityGroup = ElementwiseArityGroup::Ternary;
4116  result.kind.ternaryFn = static_cast<TernaryFn>(val - lastBinary);
4117  return result;
4118 }
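
// Decoding sketch (hypothetical enum values; the real layout comes from the
// ElementwiseKind tablegen enum): if lastUnary were 16 and a binary kind such
// as `add` had raw value 16, then val - lastUnary == 0 selects the first
// BinaryFn; ternary kinds are rebased against lastBinary the same way.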
4119 
4120 SmallVector<utils::IteratorType> ElementwiseOp::getIteratorTypesArray() {
4121  auto rank = getResultRank();
4122  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
4123 }
4124 
4125 SmallVector<AffineMap>
4126 ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
4127  MLIRContext *context) {
4128  auto map = AffineMap::getMultiDimIdentityMap(numDims, context);
4129  return SmallVector<AffineMap>(numMaps, map);
4130 }
4131 
4132 ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
4133  // Expect e.g. `kind = #linalg.elemwise_kind<add>`
4134  Attribute attr;
4135  mlir::linalg::ElementwiseKind elemwiseKindVal;
4136  if (parser.parseKeyword("kind") || parser.parseEqual())
4137  return failure();
4138 
4139  if (succeeded(parser.parseAttribute(attr))) {
4140  auto elemwiseKindAttr = dyn_cast<ElementwiseKindAttr>(attr);
4141  if (!elemwiseKindAttr)
4142  return parser.emitError(parser.getCurrentLocation(),
4143  "expected ElementwiseKind attribute");
4144  elemwiseKindVal = elemwiseKindAttr.getValue();
4145  } else {
4146  return parser.emitError(parser.getCurrentLocation(),
4147  "expected operation 'kind' attribute");
4148  }
4149  result.addAttribute(
4150  "kind", ElementwiseKindAttr::get(parser.getContext(), elemwiseKindVal));
4151 
4152  // Parse optional `indexing_maps`
4153  SmallVector<Attribute, 3> indexingMapsAttr;
4154  Attribute mapAttr;
4155  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
4156  if (parser.parseEqual())
4157  return failure();
4158  if (parser.parseLSquare())
4159  return failure();
4160  do {
4161  if (parser.parseAttribute(mapAttr))
4162  return failure();
4163  if (!isa<AffineMapAttr>(mapAttr))
4164  return parser.emitError(parser.getCurrentLocation(),
4165  "expected affine map attribute");
4166  indexingMapsAttr.push_back(mapAttr);
4167  if (parser.parseOptionalComma())
4168  break;
4169  } while (true);
4170  if (parser.parseRSquare())
4171  return failure();
4172  }
4173  // At this stage of parsing, the only way to infer the number of region
4174  // args is through the op kind, as input/output tensors are not parsed yet.
4175  auto arityGroupAndKind = getArityGroupAndKind(elemwiseKindVal);
4176  int numRegionArgs =
4177  getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1 /*output*/;
4178  if (parseNamedStructuredOp(parser, result, numRegionArgs,
4179  ElementwiseOp::getRegionBuilder())) {
4180  return parser.emitError(parser.getCurrentLocation(),
4181  "unable to parse elemwise op");
4182  }
4183 
4184  // Initialize indexingMaps, if not supplied explicitly.
4185  if (indexingMapsAttr.empty()) {
4186  // We need to infer the numDims of the indexing maps from the output
4187  // type, which has already been parsed by now.
4188  auto resultType = result.operands[result.operands.size() - 1].getType();
4189  auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
4190  if (!shapedType)
4191  return parser.emitError(parser.getCurrentLocation(),
4192  "return type needs to be shaped type");
4193  auto numDims = shapedType.getRank();
4194  indexingMapsAttr = llvm::map_to_vector(
4195  ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,
4196  parser.getContext()),
4197  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4198  }
4199 
4200  result.addAttribute("indexing_maps",
4201  parser.getBuilder().getArrayAttr(indexingMapsAttr));
4202  return success();
4203 }
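
// Textual form sketch (illustrative; indexing maps elided because they match
// the defaults):
//   linalg.elementwise kind=#linalg.elemwise_kind<add>
//     ins(%a, %b : tensor<8x16xf32>, tensor<8x16xf32>)
//     outs(%c : tensor<8x16xf32>) -> tensor<8x16xf32>
// `add` is binary, so numRegionArgs = 2 + 1 (output) = 3.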
4204 
4205 void ElementwiseOp::print(OpAsmPrinter &p) {
4206  p << " kind=";
4207  p.printAttribute(getKindAttr());
4208  SmallVector<StringRef, 3> elidedAttrs = {"operandSegmentSizes", "kind",
4209  "indexing_maps"};
4210  unsigned arity =
4211  getArityGroupAsUInt(getArityGroupAndKind(getKind()).arityGroup);
4212  unsigned numDims = getResultRank();
4213 
4214  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4215  ElementwiseOp::getDefaultIndexingMaps(arity + 1 /*output*/, numDims,
4216  getContext()),
4217  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4218 
4219  if (!llvm::equal(getIndexingMaps(), indexingMaps))
4220  p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4221 
4222  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
4223  elidedAttrs);
4224 }
4225 
4226 LogicalResult ElementwiseOp::verify() {
4227  // All necessary checks are done either by
4228  // - EnumAttr (e.g. unknown operation kind)
4229  // - verifyStructuredOpInterface (incorrect map, sizes).
4230  return success();
4231 }
4232 
4233 /// Implements the block region builder for the ElementwiseOp. This is called by
4234 /// 'fillStructuredOpRegion'.
4235 void ElementwiseOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
4236  ArrayRef<NamedAttribute> attrs) {
4237  ElementwiseKind elemwiseKind;
4238  for (auto attr : attrs) {
4239  if (attr.getName() == b.getStringAttr("kind")) {
4240  auto kindAttr = dyn_cast<ElementwiseKindAttr>(attr.getValue());
4241  assert(kindAttr && "op kind attribute incorrectly set");
4242  elemwiseKind = kindAttr.getValue();
4243  break;
4244  }
4245  }
4246 
4247  ArityGroupAndKind groupAndKind = getArityGroupAndKind(elemwiseKind);
4248  auto arityGroup = groupAndKind.arityGroup;
4249  auto kind = groupAndKind.kind;
4250  assert(block.getNumArguments() ==
4251  getArityGroupAsUInt(arityGroup) + 1 /*output*/
4252  && "Elementwise regionBuilder number of block args mismatch");
4253 
4254  RegionBuilderHelper helper(b, block);
4255  SmallVector<Value> yields;
4256  Value result;
4257 
4258  if (arityGroup == ElementwiseArityGroup::Unary) {
4259  result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));
4260 
4261  } else if (arityGroup == ElementwiseArityGroup::Binary) {
4262  result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),
4263  block.getArgument(1));
4264 
4265  } else if (arityGroup == ElementwiseArityGroup::Ternary) {
4266  result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
4267  block.getArgument(1), block.getArgument(2));
4268 
4269  } else
4270  assert(false && "found unhandled category in elemwise");
4271 
4272  yields.push_back(result);
4273  helper.yieldOutputs(yields);
4274 }
4275 
4276 LogicalResult ElementwiseOp::fold(FoldAdaptor,
4277  SmallVectorImpl<OpFoldResult> &) {
4278  return memref::foldMemRefCast(*this);
4279 }
4280 
4281 void ElementwiseOp::getEffects(
4282  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4283  &effects) {
4284  if (hasPureTensorSemantics())
4285  return;
4286  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4287 }
4288 
4289 Speculation::Speculatability ElementwiseOp::getSpeculatability() {
4290  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
4291 }
4292 
4293 //===----------------------------------------------------------------------===//
4294 // PackOp/UnPackOp Common
4295 //===----------------------------------------------------------------------===//
4296 // Given the (potentially) updated packed type, `newPackedTy`, generates an
4297 // updated mixed-tile-sizes attribute. A tile size is updated only
4298 // when:
4299 // * a dim from newPackedTy is static, and
4300 // * the corresponding size from mixedTiles is still dynamic.
4301 // Otherwise, the original tile size is preserved.
4302 // Note - packed-type-dim and mixed-tile-size should always match!
4303 static SmallVector<OpFoldResult>
4304 getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
4305  SmallVector<OpFoldResult> mixedTiles) {
4306  SmallVector<OpFoldResult> newMixedTileSizes;
4307  for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
4308  .getShape()
4309  .take_back(mixedTiles.size()),
4310  mixedTiles)) {
4311  int64_t shape = std::get<0>(it);
4312  if (shape == ShapedType::kDynamic) {
4313  newMixedTileSizes.push_back(std::get<1>(it));
4314  continue;
4315  }
4316 
4317  // If the current result dim is static, update the dynamic mixed-size
4318  // (provided the original value is dynamic).
4319  OpFoldResult tile = std::get<1>(it);
4320  if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tile)) {
4321  // Already a constant
4322  newMixedTileSizes.push_back(tile);
4323  } else {
4324  assert(getConstantIntValue(tile).value() == shape &&
4325  "tile size and dim size don't match!");
4326  newMixedTileSizes.push_back(
4327  (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
4328  }
4329  }
4330 
4331  return newMixedTileSizes;
4332 }
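
// Worked example (hypothetical values): if newPackedTy has trailing dims
// [?, 8] and mixedTiles is [%t0, %c8], the dynamic dim keeps %t0 unchanged,
// while the second entry is replaced by the constant index attribute 8
// (assuming %c8 folds to 8, which the assertion above checks).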
4333 
4334 template <typename OpTy>
4335 static LogicalResult
4336 reifyResultShapesImpl(OpTy op, OpBuilder &builder,
4337  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4338  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4339  "applies to only pack or unpack operations");
4340  int64_t destRank = op.getDestRank();
4341  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
4342  reifiedReturnShapes[0] =
4343  tensor::getMixedSizes(builder, op.getLoc(), op.getDest());
4344  return success();
4345 }
4346 
4347 template <typename OpTy>
4348 static DenseMap<int64_t, OpFoldResult> getDimAndTileMappingImpl(OpTy op) {
4349  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4350  "applies to only pack or unpack operations");
4351  DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
4352  ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
4353  SmallVector<OpFoldResult> tiles = op.getMixedTiles();
4354  assert(tiles.size() == dimsToTile.size() &&
4355  "tiles must match indices of dimension to block");
4356  // Bind dimension `i` to its tile factor.
4357  for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
4358  dimAndTileMapping[dimsToTile[i]] = tiles[i];
4359  return dimAndTileMapping;
4360 }
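// E.g. (illustrative): for inner_dims_pos = [1, 0] and mixed tiles [32, 8],
// the mapping built above is {1 -> 32, 0 -> 8}.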
4361 
4362 template <typename OpTy>
4363 static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
4364  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4365  "applies to only pack or unpack operations");
4366  Builder builder(op);
4367  SmallVector<OpFoldResult> mixedInnerTiles;
4368  unsigned dynamicValIndex = 0;
4369  for (int64_t staticTile : op.getStaticInnerTiles()) {
4370  if (!ShapedType::isDynamic(staticTile))
4371  mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
4372  else
4373  mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
4374  }
4375  return mixedInnerTiles;
4376 }
4377 
4378 template <typename OpTy>
4379 static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
4380  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4381  "applies to only pack or unpack operations");
4382  SmallVector<Value> dynamicTiles;
4383  SmallVector<int64_t> staticTiles;
4384  dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles);
4385  return staticTiles;
4386 }
4387 
4388 /// Returns true if `dimsPos` is invalid. It is invalid when:
4389 /// a) It contains duplicates.
4390 /// b) At least one dimension is out of bounds (a valid `dimPos` satisfies 0 <= dimPos < rank).
4391 /// c) The number of elements in `dimsPos` is greater than `rank`.
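/// For example (illustrative), with rank = 3: [0, 2] is valid, while [1, 1]
/// (duplicate), [0, 3] (out of bounds) and [0, 1, 2, 0] (too many entries)
/// are all invalid.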
4392 static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
4393  size_t rank) {
4394  size_t dimsPosSize = dimsPos.size();
4395  if (dimsPosSize > rank)
4396  return true;
4397  DenseSet<int64_t> uniqued(llvm::from_range, dimsPos);
4398  if (dimsPosSize != uniqued.size())
4399  return true;
4400  return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
4401  return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
4402  });
4403 }
4404 
4405 /// Returns true if each dimension of `sourceShape` is no larger than the
4406 /// corresponding dimension of `limitShape` (dynamic extents always pass).
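/// E.g. (illustrative): sourceShape = [4, ?] is in bound w.r.t.
/// limitShape = [8, 2], since 4 <= 8 and the dynamic extent is accepted
/// unconditionally.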
4407 static bool areAllInBound(ArrayRef<int64_t> sourceShape,
4408  ArrayRef<int64_t> limitShape) {
4409  assert(
4410  sourceShape.size() == limitShape.size() &&
4411  "expected source shape and limit shape to have the same rank");
4412  return llvm::all_of(
4413  llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
4414  int64_t sourceExtent = std::get<0>(it);
4415  int64_t limit = std::get<1>(it);
4416  return ShapedType::isDynamic(sourceExtent) ||
4417  ShapedType::isDynamic(limit) || sourceExtent <= limit;
4418  });
4419 }
4420 
4421 template <typename OpTy>
4422 static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
4423  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4424  "applies to only pack or unpack operations");
4425  Operation *op = packOrUnPack.getOperation();
4426 
4427  // Return true if we have a zero-value tile.
4428  auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
4429  return llvm::any_of(
4430  tiles, [](OpFoldResult tile) { return isConstantIntValue(tile, 0); });
4431  };
4432 
4433  // Verify tiles. Do not allow zero tiles.
4434  SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
4435  if (hasZeros(mixedTiles))
4436  return op->emitError("invalid zero tile factor");
4437 
4438  // Verify inner_dims_pos and outer_dims_perm.
4439  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
4440  ? packOrUnPack.getSourceType()
4441  : packOrUnPack.getDestType();
4442  size_t unpackedRank = unpackedType.getRank();
4443  ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
4444  ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
4445  if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
4446  return op->emitError("invalid inner_dims_pos vector");
4447  if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
4448  return op->emitError("invalid outer_dims_perm vector");
4449  if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
4450  return op->emitError("outer_dims_perm must be a permutation or empty");
4451 
4452  // Tiling factors must be less than or equal to the input rank for pack (or
4453  // output rank for unpack), and must match the number of `inner_dims_pos`.
4454  if (mixedTiles.size() > unpackedRank) {
4455  return op->emitError("tiling factors must be less than or equal to the "
4456  "input rank for pack or output rank for unpack");
4457  }
4458  if (mixedTiles.size() != innerDimsPos.size()) {
4459  return op->emitError(
4460  "tiling factors must equal the number of dimensions to tile");
4461  }
4462 
4463  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
4464  ? packOrUnPack.getDestType()
4465  : packOrUnPack.getSourceType();
4466  size_t packedRank = packedType.getRank();
4467  // Require output rank to match input rank + number of blocking factors.
4468  size_t expectedPackedRank = unpackedRank + mixedTiles.size();
4469  if (expectedPackedRank != packedRank) {
4470  return op->emitError(
4471  "packed rank != (unpacked rank + num tiling factors), got ")
4472  << packedRank << " != " << expectedPackedRank;
4473  }
4474 
4475  // Verify that the result shape is at least as large as the minimum shape
4476  // expected by the pack operation, and that the output shape represents
4477  // full tiles.
4478  RankedTensorType expectedPackedType = PackOp::inferPackedType(
4479  unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
4480  if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
4481  return op->emitError("the shape of output is not large enough to hold the "
4482  "packed data. Expected at least ")
4483  << expectedPackedType << ", got " << packedType;
4484  }
4485  if (!llvm::all_of(
4486  llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
4487  mixedTiles),
4488  [](std::tuple<int64_t, OpFoldResult> it) {
4489  int64_t shape = std::get<0>(it);
4490  if (Attribute attr =
4491  llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
4492  IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
4493  int64_t staticTileSize = intAttr.getValue().getSExtValue();
4494  return shape == staticTileSize;
4495  }
4496  return ShapedType::isDynamic(shape);
4497  })) {
4498  return op->emitError("mismatch between the inner tile sizes specified and "
4499  "the shape of the tiled dimensions in the packed type");
4500  }
4501  return success();
4502 }
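// An op that satisfies all of the invariants checked above (illustrative,
// the values and types are made up):
//
// ```mlir
// %pack = linalg.pack %src padding_value(%pad : f32)
//     inner_dims_pos = [0, 1] inner_tiles = [8, 32]
//     into %dest : tensor<129x47xf32> -> tensor<17x2x8x32xf32>
// ```
//
// Here the unpacked rank is 2, the packed rank is 2 + 2 tiling factors = 4,
// 17x2 covers ceil(129/8) x ceil(47/32), and the trailing packed dims
// [8, 32] equal the inner tiles.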
4503 
4504 namespace {
4505 /// Subset of PackOp/UnPackOp fields used to compute the result of applying
4506 /// various permutations to the op.
4507 // TODO: Add linalg.transpose + pack/unpack folding patterns that just reuse
4508 // these. These may or may not become true foldings / canonicalizations
4509 // depending on how aggressive we want to be in automatically folding
4510 // transposes.
4511 struct PackOrUnPackTransposeResult {
4512  SmallVector<int64_t> innerDimsPos;
4513  SmallVector<OpFoldResult> innerTiles;
4514  SmallVector<int64_t> outerDimsPerm;
4515 };
4516 } // namespace
4517 
4518 template <typename OpTy>
4519 static PackOrUnPackTransposeResult
4520 commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
4521  ArrayRef<int64_t> innerPermutation,
4522  ArrayRef<int64_t> outerPermutation) {
4523  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4524  "applies to only pack or unpack operations");
4525  assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
4526  "some permutation must be non-empty");
4527  PackOrUnPackTransposeResult metadata;
4528  metadata.innerDimsPos =
4529  SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
4530  metadata.innerTiles =
4531  SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
4532  int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
4533  ? packOrUnPackOp.getSourceRank()
4534  : packOrUnPackOp.getDestRank();
4535  metadata.outerDimsPerm =
4536  packOrUnPackOp.getOuterDimsPerm().empty()
4537  ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
4538  : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
4539  if (!innerPermutation.empty()) {
4540  assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
4541  isPermutationVector(innerPermutation) &&
4542  "invalid inner permutation");
4543  applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
4544  applyPermutationToVector(metadata.innerTiles, innerPermutation);
4545  }
4546  if (!outerPermutation.empty()) {
4547  assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
4548  isPermutationVector(outerPermutation) &&
4549  "invalid outer permutation");
4550  applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
4551  }
4552  return metadata;
4553 }
4554 
4555 //===----------------------------------------------------------------------===//
4556 // PackOp
4557 //===----------------------------------------------------------------------===//
4558 
4559 void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
4560  setNameFn(getResult(), "pack");
4561 }
4562 
4563 void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
4564  Value dest, ArrayRef<int64_t> innerDimsPos,
4565  ArrayRef<OpFoldResult> innerTiles,
4566  std::optional<Value> paddingValue,
4567  ArrayRef<int64_t> outerDimsPerm) {
4568  assert(innerDimsPos.size() == innerTiles.size() &&
4569  "number of tile sizes specified must match the specified number of "
4570  "original dimensions to be tiled");
4571  SmallVector<int64_t> staticTileSizes;
4572  SmallVector<Value> dynamicTileSizes;
4573  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
4574  build(builder, state, dest.getType(), source, dest,
4575  paddingValue ? *paddingValue : nullptr,
4576  outerDimsPerm.empty() ? nullptr
4577  : builder.getDenseI64ArrayAttr(outerDimsPerm),
4578  builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
4579  builder.getDenseI64ArrayAttr(staticTileSizes));
4580 }
4581 
4582 LogicalResult
4583 PackOp::reifyResultShapes(OpBuilder &builder,
4584  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4585  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
4586 }
4587 
4588 DenseMap<int64_t, OpFoldResult> PackOp::getDimAndTileMapping() {
4589  return getDimAndTileMappingImpl(*this);
4590 }
4591 
4592 SmallVector<OpFoldResult> PackOp::getMixedTiles() {
4593  return getMixedTilesImpl(*this);
4594 }
4595 
4596 SmallVector<int64_t> PackOp::getStaticTiles() {
4597  return getStaticTilesImpl(*this);
4598 }
4599 
4600 ArrayRef<int64_t> PackOp::getAllOuterDims() {
4601  ShapedType inputType = getSourceType();
4602  int64_t inputRank = inputType.getRank();
4603  return getDestType().getShape().take_front(inputRank);
4604 }
4605 
4606 SmallVector<int64_t> PackOp::getTiledOuterDims() {
4607  auto innerDimsPos = getInnerDimsPos();
4608  auto packedShape = getDestType().getShape();
4609  SmallVector<int64_t> res;
4610 
4611  for (auto index : innerDimsPos)
4612  res.push_back(packedShape[index]);
4613 
4614  return res;
4615 }
4616 
4617 bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
4618  ArrayRef<int64_t> innerDimsPos,
4619  ArrayRef<int64_t> outputShape,
4620  ArrayRef<int64_t> outerDimsPerm,
4621  ArrayRef<OpFoldResult> innerTiles) {
4622  SmallVector<int64_t> outputTileSizes(
4623  outputShape.take_front(inputShape.size()));
4624  if (!outerDimsPerm.empty()) {
4625  assert(outerDimsPerm.size() == outputTileSizes.size() &&
4626  "expected output and outer_dims_perm to have same size");
4627  applyPermutationToVector(outputTileSizes,
4628  invertPermutationVector(outerDimsPerm));
4629  }
4630  for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
4631  if (ShapedType::isDynamic(inputShape[pos]))
4632  continue;
4633  std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
4634 
4635  if (!constantTile) {
4636  if (!ShapedType::isDynamic(outputTileSizes[pos]) &&
4637  (inputShape[pos] % outputTileSizes[pos] != 0))
4638  return true;
4639  } else if (inputShape[pos] % (*constantTile) != 0) {
4640  return true;
4641  }
4642  }
4643  return false;
4644 }
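// For the check above (illustrative): inputShape = [13] with a constant
// inner tile of 8 requires padding, since 13 % 8 != 0; inputShape = [16]
// with the same tile does not.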
4645 
4646 LogicalResult PackOp::verify() {
4647  if (failed(commonVerifierPackAndUnPackOp(*this)))
4648  return failure();
4649 
4650  // Verify padding value, and bail out if the tile does not divide the
4651  // dimension fully. In the case of dynamic tile factors or dimensions, having
4652  // a partial tile is undefined behavior.
4653  auto paddingValue = getPaddingValue();
4654  if (paddingValue &&
4655  paddingValue.getType() != getSourceType().getElementType()) {
4656  return emitOpError("expected padding_value has ")
4657  << getSourceType().getElementType()
4658  << " but got: " << paddingValue.getType();
4659  }
4660 
4661  if (!paddingValue &&
4662  requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
4663  getDestType().getShape(), getOuterDimsPerm(),
4664  getMixedTiles())) {
4665  return emitOpError(
4666  "invalid tile factor or output size provided. Only full tiles are "
4667  "supported when padding_value is not set");
4668  }
4669  return success();
4670 }
4671 
4672 /// Converts OpFoldResults to int64_t shape entries, unconditionally mapping
4673 /// all Values to kDynamic, even if they are arith.constant values.
4674 static SmallVector<int64_t>
4675 asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
4676  SmallVector<int64_t> result;
4677  for (auto o : ofrs) {
4678  // Have to do this first, as getConstantIntValue special-cases constants.
4679  if (llvm::dyn_cast_if_present<Value>(o))
4680  result.push_back(ShapedType::kDynamic);
4681  else
4682  result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
4683  }
4684  return result;
4685 }
4686 
4687 /// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of
4688 /// the packed type. Having a shared helper helps implement these two methods in
4689 /// a way that ensures that they agree on which dimensions are dynamic.
4690 static SmallVector<int64_t> getPackOpResultTypeShape(
4691  ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
4692  ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
4693  SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape);
4694  for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
4695  if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
4696  continue;
4697  if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
4698  resultShape[tiledDim.value()] = ShapedType::kDynamic;
4699  continue;
4700  }
4701  resultShape[tiledDim.value()] = llvm::divideCeilSigned(
4702  resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
4703  }
4704 
4705  // Swap tile loops if outer_dims_perm is available.
4706  if (!outerDimsPerm.empty())
4707  applyPermutationToVector(resultShape, outerDimsPerm);
4708 
4709  // Append the inner tile dimensions.
4710  resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
4711  return resultShape;
4712 }
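// Worked example for the helper above (illustrative): sourceShape = [20, ?],
// innerTileSizes = [4, 8], innerDimsPos = [0, 1], outerDimsPerm = [1, 0]:
//   tiling:  [20, ?] -> [5, ?]   (20 ceildiv 4; a dynamic dim stays dynamic)
//   permute: [5, ?]  -> [?, 5]
//   append:  [?, 5, 4, 8]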
4713 
4714 SmallVector<OpFoldResult> PackOp::getResultShape(
4715  OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
4716  ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
4717  ArrayRef<int64_t> outerDimsPerm) {
4718  SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
4719 
4720  AffineExpr s0, s1;
4721  bindSymbols(builder.getContext(), s0, s1);
4722  AffineExpr ceilDivExpr = s0.ceilDiv(s1);
4723  for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
4724  resultDims[tiledDim.value()] = affine::makeComposedFoldedAffineApply(
4725  builder, loc, ceilDivExpr,
4726  {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
4727  }
4728  if (!outerDimsPerm.empty())
4729  applyPermutationToVector(resultDims, outerDimsPerm);
4730  resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
4731 
4732  SmallVector<int64_t> resultTypeShape =
4733  getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(resultDims),
4734  asShapeWithAnyValueAsDynamic(innerTileSizes),
4735  innerDimsPos, outerDimsPerm);
4736 
4737  // Fix-up `resultDims` to ensure that they are Value's if and only if the
4738  // result type shape says it's a dynamic dim. This is needed as callers may
4739  // use dispatchIndexOpFoldResults on the result, and rely on exact number of
4740  // dynamic dims returned by that.
4741  for (unsigned i = 0; i < resultDims.size(); ++i) {
4742  if (!ShapedType::isDynamic(resultTypeShape[i]))
4743  continue;
4744  resultDims[i] =
4745  getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
4746  }
4747 
4748  return resultDims;
4749 }
4750 
4751 /// Get the expected packed type based on source type, tile factors, position of
4752 /// the inner tiles and permutation of the outer tiled loop.
4753 RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
4754  ArrayRef<int64_t> innerTileSizes,
4755  ArrayRef<int64_t> innerDimsPos,
4756  ArrayRef<int64_t> outerDimsPerm) {
4757  SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
4758  sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
4759  return RankedTensorType::get(resultShape, sourceType.getElementType());
4760 }
4761 
4762 Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
4763  ArrayRef<OpFoldResult> innerTileSizes,
4764  ArrayRef<int64_t> innerDimsPos,
4765  ArrayRef<int64_t> outerDimsPerm) {
4766  AffineExpr dim0, dim1;
4767  bindDims(b.getContext(), dim0, dim1);
4768  auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
4769  return affine::makeComposedFoldedAffineApply(b, loc, dim0.ceilDiv(dim1),
4770  {v1, v2});
4771  };
4772 
4773  SmallVector<OpFoldResult> mixedSizes;
4774  for (auto [index, value] : llvm::enumerate(
4775  llvm::cast<RankedTensorType>(source.getType()).getShape())) {
4776  if (ShapedType::isDynamic(value))
4777  mixedSizes.push_back(
4778  b.create<tensor::DimOp>(loc, source, index).getResult());
4779  else
4780  mixedSizes.push_back(b.getIndexAttr(value));
4781  }
4782  for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
4783  int64_t dimPos = std::get<0>(it);
4784  OpFoldResult tileSize = std::get<1>(it);
4785  mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
4786  }
4787  if (!outerDimsPerm.empty())
4788  applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);
4789 
4790  mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
4791  auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4792  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
4793 }
4794 
4795 PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
4796  ArrayRef<int64_t> innerPermutation,
4797  ArrayRef<int64_t> outerPermutation) {
4798  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
4799  *this, innerPermutation, outerPermutation);
4800  Value transposedDest =
4801  createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
4802  metadata.innerDimsPos, metadata.outerDimsPerm);
4803  return b.create<PackOp>(loc, getSource(), transposedDest,
4804  metadata.innerDimsPos, metadata.innerTiles,
4805  getPaddingValue(), metadata.outerDimsPerm);
4806 }
4807 
4808 /// Returns true if the tiles and the tiled dims are constant.
4809 template <typename OpTy>
4810 static bool areTilesAndTiledDimsAllConstant(OpTy op) {
4811  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4812  "applies to only pack or unpack operations");
4813  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
4814  ? op.getDestType()
4815  : op.getSourceType();
4816  SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
4817  for (auto [dimDest, tile] : llvm::zip(
4818  packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
4819  std::optional<int64_t> constTileSize = getConstantIntValue(tile);
4820  if (!constTileSize || ShapedType::isDynamic(dimDest))
4821  return false;
4822  }
4823  return true;
4824 }
4825 
4826 Speculation::Speculatability PackOp::getSpeculatability() {
4827  if (getPaddingValue())
4828  return Speculation::Speculatable;
4829 
4830  // The verifier already rejects operations where we can statically prove
4831  // that the tile sizes do not evenly divide the dimension; thus, only check
4832  // that the tiles and the tiled inner dimensions are constant.
4833  if (!areTilesAndTiledDimsAllConstant(*this))
4834  return Speculation::NotSpeculatable;
4835 
4836  return Speculation::Speculatable;
4837 }
4838 
4839 // Return true if `inner_dims_pos` and `outer_dims_perm` target the same
4840 // dimensions for pack and unpack.
4841 static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
4842  if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
4843  return false;
4844  if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
4845  return true;
4846  // Outer dims permutation is optional.
4847  // To compare an unbalanced pack-unpack pair, treat a missing permutation
4848  // as the identity permutation.
4849  return isIdentityPermutation(packOp.getOuterDimsPerm()) &&
4850  isIdentityPermutation(unPackOp.getOuterDimsPerm());
4851 }
4852 
4853 // Return true if pack and unpack have the same tiles.
4854 // Same SSA values or same integer constants.
4855 static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
4856  auto packTiles = packOp.getMixedTiles();
4857  auto unPackTiles = unPackOp.getMixedTiles();
4858  if (packTiles.size() != unPackTiles.size())
4859  return false;
4860  for (size_t i = 0, e = packTiles.size(); i < e; i++) {
4861  if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
4862  return false;
4863  }
4864  return true;
4865 }
4866 
4867 /// Returns true if the pack op does not need a padding value.
4868 static bool paddingIsNotNeeded(PackOp op) {
4869  auto srcType = op.getSourceType();
4870  if (llvm::any_of(op.getInnerDimsPos(),
4871  [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
4872  return false;
4873  if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
4874  return false;
4875  return !PackOp::requirePaddingValue(
4876  srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
4877  op.getOuterDimsPerm(), op.getMixedTiles());
4878 }
4879 
4880 /// Returns true if the `srcShape` or `destShape` is different from the one in
4881 /// `packOp` and populates each with the inferred static shape.
4882 static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
4883  SmallVectorImpl<int64_t> &destShape) {
4884  bool changeNeeded = false;
4885  srcShape.assign(packOp.getSourceType().getShape().begin(),
4886  packOp.getSourceType().getShape().end());
4887  destShape.assign(packOp.getDestType().getShape().begin(),
4888  packOp.getDestType().getShape().end());
4889  llvm::SmallSetVector<int64_t, 4> innerDims;
4890  innerDims.insert_range(packOp.getInnerDimsPos());
4891  SmallVector<int64_t> inverseOuterDimsPerm;
4892  if (!packOp.getOuterDimsPerm().empty())
4893  inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
4894  int srcRank = packOp.getSourceRank();
4895  for (auto i : llvm::seq<int64_t>(0, srcRank)) {
4896  if (innerDims.contains(i))
4897  continue;
4898  int64_t srcPos = i;
4899  int64_t destPos = i;
4900  if (!inverseOuterDimsPerm.empty())
4901  destPos = inverseOuterDimsPerm[srcPos];
4902  if (ShapedType::isDynamic(srcShape[srcPos]) ==
4903  ShapedType::isDynamic(destShape[destPos])) {
4904  continue;
4905  }
4906  int64_t size = srcShape[srcPos];
4907  if (ShapedType::isDynamic(size))
4908  size = destShape[destPos];
4909  srcShape[srcPos] = size;
4910  destShape[destPos] = size;
4911  changeNeeded = true;
4912  }
4913  return changeNeeded;
4914 }
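// Example of the inference above (illustrative): with inner_dims_pos = [1],
// srcShape = [?, 32] and destShape starting with [4, ?], dim 0 is untiled
// and static only on the dest side, so srcShape becomes [4, 32] and a change
// is reported; dim 1 is tiled and therefore skipped.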
4915 
4916 LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
4917  // Fold a pack(unpack(x)) to x.
4918  if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
4919  if (unPackOp.getSourceType() != packOp.getDestType())
4920  return failure();
4921  if (packOp.getPaddingValue() ||
4922  !hasSameInnerOuterAttribute(packOp, unPackOp) ||
4923  !haveSameTiles(packOp, unPackOp))
4924  return failure();
4925  rewriter.replaceOp(packOp, unPackOp.getSource());
4926  return success();
4927  }
4928 
4929  // Fold optional PaddingValue operand away if padding is not needed.
4930  if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
4931  rewriter.startOpModification(packOp);
4932  packOp.getPaddingValueMutable().clear();
4933  rewriter.finalizeOpModification(packOp);
4934  return success();
4935  }
4936 
4937  // Insert tensor.cast ops if static shape inference is available.
4938  SmallVector<int64_t> srcShape, destShape;
4939  if (inferStaticShape(packOp, srcShape, destShape)) {
4940  Location loc = packOp.getLoc();
4941  Value source = packOp.getSource();
4942  if (srcShape != packOp.getSourceType().getShape()) {
4943  auto newSrcType = packOp.getSourceType().clone(srcShape);
4944  source =
4945  rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
4946  }
4947  Value dest = packOp.getDest();
4948  RankedTensorType originalResultType = packOp.getDestType();
4949  bool needUpdateDestType = (destShape != originalResultType.getShape());
4950  if (needUpdateDestType) {
4951  auto newDestType = packOp.getDestType().clone(destShape);
4952  dest =
4953  rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
4954  }
4955  rewriter.modifyOpInPlace(packOp, [&] {
4956  packOp.getSourceMutable().assign(source);
4957  packOp.getDestMutable().assign(dest);
4958  packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
4959  });
4960  // Insert a cast if needed
4961  if (needUpdateDestType) {
4962  rewriter.setInsertionPointAfter(packOp);
4963  auto castOp =
4964  rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
4965  rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
4966  }
4967  return success();
4968  }
4969 
4970  return failure();
4971 }
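// Example of the pack(unpack(x)) fold above (illustrative, names made up):
//
// ```mlir
// %u = linalg.unpack %x inner_dims_pos = [0] inner_tiles = [8]
//     into %e : tensor<4x8xf32> -> tensor<32xf32>
// %p = linalg.pack %u inner_dims_pos = [0] inner_tiles = [8]
//     into %d : tensor<32xf32> -> tensor<4x8xf32>
// ```
//
// canonicalizes to plain %x, provided the tiles, inner_dims_pos and
// outer_dims_perm of the two ops match and no padding value is set.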
4972 
4973 template <typename PackOrUnpackOp>
4974 static bool isLikePadUnPad(PackOrUnpackOp packOp,
4975  RankedTensorType packedTensorType) {
4976  static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
4977  std::is_same<PackOrUnpackOp, UnPackOp>::value,
4978  "Function meant for pack/unpack");
4979  // This is a pad if packing only adds ones and we don't transpose dimensions.
4980 
4981  // Check that we are not transposing any dimensions.
4982  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
4983  int64_t numPackedDims = innerDimsPos.size();
4984  auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
4985  if (orderedDims != innerDimsPos) {
4986  // Dimensions don't happen in order.
4987  return false;
4988  }
4989 
4990  ArrayRef<int64_t> packedShape = packedTensorType.getShape();
4991  int64_t packedRank = packedTensorType.getRank();
4992  // At this point we know that we are taking numPackedDims outer
4993  // dimensions and pushing them all the way as the inner most dimensions.
4994  // What's left on the outer most dimensions is, in this order:
4995  // - the factor of the packed dimensions, then
4996  // - the untouched dimensions
4997  // This shifting inward of dimensions is a no-op (as opposed to a transpose)
4998 // if all the dimensions that bubble outward are ones.
4999 // Therefore check that all the dimensions but the numPackedDims innermost
5000 // ones are ones.
5001  return llvm::all_of(
5002  llvm::seq<int64_t>(0, packedRank - numPackedDims),
5003  [&packedShape](int64_t i) { return packedShape[i] == 1; });
5004 }
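// Illustrative case: packing tensor<5x7xf32> into tensor<1x1x8x8xf32> with
// inner_dims_pos = [0, 1] only introduces unit outer dims, so it behaves
// like a pad; inner_dims_pos = [1, 0] would transpose and does not.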
5005 
5006 bool PackOp::isLikePad() {
5007  auto packedTensorType =
5008  llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
5009  return isLikePadUnPad(*this, packedTensorType);
5010 }
5011 
5012 OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
5013  std::optional<Attribute> paddingValue;
5014  if (auto pad = adaptor.getPaddingValue())
5015  paddingValue = pad;
5016  if (OpFoldResult reshapedSource = reshapeConstantSource(
5017  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
5018  getDestType(), paddingValue))
5019  return reshapedSource;
5020  return {};
5021 }
5022 
5023 /// Folds a tensor.cast op into a consuming PackOp if the
5024 /// `tensor.cast` has a source that is more static than the consuming op.
5025 ///
5026 /// Example:
5027 /// ```mlir
5028 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
5029 /// %2 = tensor.pack %1 ... : tensor<?x?xf32> ...
5030 /// ```
5031 ///
5032 /// folds into:
5033 ///
5034 /// ```mlir
5035 /// %2 = tensor.pack %0 ... : tensor<8x16xf32> ...
5036 /// ```
5037 struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
5038  using OpRewritePattern<PackOp>::OpRewritePattern;
5039 
5040  LogicalResult matchAndRewrite(PackOp op,
5041  PatternRewriter &rewriter) const override {
5042  if (!tensor::hasFoldableTensorCastOperand(op))
5043  return failure();
5045  SmallVector<Type> newResultTypes(op->getResultTypes());
5046  SmallVector<Value> newOperands =
5047  tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
5048 
5049  // Get the updated mixed-tile-sizes attribute.
5050  SmallVector<OpFoldResult> newMixedTileSizes =
5051  getNewMixedTileSizes(rewriter, newResultTypes[0], op.getMixedTiles());
5052 
5053  // Clone op.
5054  // TODO: Strictly speaking, discardable attributes should be _discarded_ at
5055  // this point. However, in practice, we use them for things that we'd like
5056  // to preserve. Implement a better abstraction.
5057  PackOp newOp = rewriter.create<PackOp>(
5058  op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
5059  newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());
5060  newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
5061 
5062  // Replace op.
5063  Value oldResult = op.getResult();
5064  Value newResult = newOp.getResult();
5065  Value replacement = (newResult.getType() != oldResult.getType())
5066  ? rewriter.create<tensor::CastOp>(
5067  op->getLoc(), oldResult.getType(), newResult)
5068  : newResult;
5069 
5070  rewriter.replaceOp(op, {replacement});
5071 
5072  return success();
5073  }
5074 };
5075 
5076 //===----------------------------------------------------------------------===//
5077 // UnPackOp
5078 //===----------------------------------------------------------------------===//
5079 
5080 void UnPackOp::getAsmResultNames(
5081  function_ref<void(Value, StringRef)> setNameFn) {
5082  setNameFn(getResult(), "unpack");
5083 }
5084 
5085 LogicalResult
5086 UnPackOp::reifyResultShapes(OpBuilder &builder,
5087  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
5088  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
5089 }
5090 
5091 DenseMap<int64_t, OpFoldResult> UnPackOp::getDimAndTileMapping() {
5092  return getDimAndTileMappingImpl(*this);
5093 }
5094 
5095 SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
5096  return getMixedTilesImpl(*this);
5097 }
5098 
5099 SmallVector<int64_t> UnPackOp::getStaticTiles() {
5100  return getStaticTilesImpl(*this);
5101 }
5102 
5103 ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
5104  ShapedType destType = getDestType();
5105  int64_t destRank = destType.getRank();
5106  return getSourceType().getShape().take_front(destRank);
5107 }
5108 
5109 SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
5110  auto innerDimsPos = getInnerDimsPos();
5111  auto packedShape = getSourceType().getShape();
5112  SmallVector<int64_t> res;
5113 
5114  for (auto index : innerDimsPos)
5115  res.push_back(packedShape[index]);
5116 
5117  return res;
5118 }
5119 
5120 LogicalResult UnPackOp::verify() {
5121  return commonVerifierPackAndUnPackOp(*this);
5122 }
5123 
5124 Speculation::Speculatability UnPackOp::getSpeculatability() {
5125  // See PackOp::getSpeculatability.
5126  if (!areTilesAndTiledDimsAllConstant(*this))
5127  return Speculation::NotSpeculatable;
5128 
5129  return Speculation::Speculatable;
5130 }
5131 
5132 void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
5133  Value dest, ArrayRef<int64_t> innerDimsPos,
5134  ArrayRef<OpFoldResult> innerTiles,
5135  ArrayRef<int64_t> outerDimsPerm) {
5136  assert(innerDimsPos.size() == innerTiles.size() &&
5137  "number of tile sizes specified must match the specified number of "
5138  "original dimensions to be tiled");
5139  SmallVector<int64_t> staticTileSizes;
5140  SmallVector<Value> dynamicTileSizes;
5141  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
5142  build(builder, state, dest.getType(), source, dest,
5143  outerDimsPerm.empty() ? nullptr
5144  : builder.getDenseI64ArrayAttr(outerDimsPerm),
5145  builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
5146  builder.getDenseI64ArrayAttr(staticTileSizes));
5147 }
5148 
5149 Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
5150  Value source,
5151  ArrayRef<OpFoldResult> innerTileSizes,
5152  ArrayRef<int64_t> innerDimsPos,
5153  ArrayRef<int64_t> outerDimsPerm) {
5154  AffineExpr sym0, sym1;
5155  bindSymbols(b.getContext(), sym0, sym1);
5156  auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
5157  return affine::makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2});
5158  };
5159 
5160  SmallVector<OpFoldResult> mixedSizes;
5161  auto srcType = llvm::cast<RankedTensorType>(source.getType());
5162  for (auto i :
5163  llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
5164  if (srcType.isDynamicDim(i))
5165  mixedSizes.push_back(b.create<tensor::DimOp>(loc, source, i).getResult());
5166  else
5167  mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
5168  }
5169  if (!outerDimsPerm.empty()) {
5170  applyPermutationToVector<OpFoldResult>(
5171  mixedSizes, invertPermutationVector(outerDimsPerm));
5172  }
5173 
5174  for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
5175  mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
5176 
5177  auto elemType = srcType.getElementType();
5178  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
5179 }
5180 
5181 UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
5182  Value transposedSource,
5183  ArrayRef<int64_t> innerPermutation,
5184  ArrayRef<int64_t> outerPermutation) {
5185  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
5186  *this, innerPermutation, outerPermutation);
5187  return b.create<UnPackOp>(loc, transposedSource, getDest(),
5188  metadata.innerDimsPos, metadata.innerTiles,
5189  metadata.outerDimsPerm);
5190 }
5191 
5192 /// Returns true if the `srcShape` or `destShape` is different from the one in
5193 /// `op` and populates each with the inferred static shape.
5194 static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
5195  SmallVectorImpl<int64_t> &destShape) {
5196  bool changeNeeded = false;
5197  srcShape.assign(op.getSourceType().getShape().begin(),
5198  op.getSourceType().getShape().end());
5199  destShape.assign(op.getDestType().getShape().begin(),
5200  op.getDestType().getShape().end());
5201  llvm::SmallSetVector<int64_t, 4> innerDims;
5202  innerDims.insert_range(op.getInnerDimsPos());
5203  SmallVector<int64_t> inverseOuterDimsPerm;
5204  if (!op.getOuterDimsPerm().empty())
5205  inverseOuterDimsPerm = invertPermutationVector(op.getOuterDimsPerm());
5206  int destRank = op.getDestRank();
5207  for (auto i : llvm::seq<int64_t>(0, destRank)) {
5208  if (innerDims.contains(i))
5209  continue;
5210  int64_t srcPos = i;
5211  int64_t destPos = i;
5212  if (!inverseOuterDimsPerm.empty())
5213  srcPos = inverseOuterDimsPerm[destPos];
5214  if (ShapedType::isDynamic(srcShape[srcPos]) ==
5215  ShapedType::isDynamic(destShape[destPos])) {
5216  continue;
5217  }
5218  int64_t size = srcShape[srcPos];
5219  if (ShapedType::isDynamic(size))
5220  size = destShape[destPos];
5221  srcShape[srcPos] = size;
5222  destShape[destPos] = size;
5223  changeNeeded = true;
5224  }
5225  return changeNeeded;
5226 }
5227 
5228 LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
5229  PatternRewriter &rewriter) {
5230  /// unpack(pack(x)) -> x
5231  if (PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>()) {
5232  if (packOp.getSourceType() != unPackOp.getDestType())
5233  return failure();
5234  if (packOp.getPaddingValue() ||
5235  !hasSameInnerOuterAttribute(packOp, unPackOp) ||
5236  !haveSameTiles(packOp, unPackOp))
5237  return failure();
5238  rewriter.replaceOp(unPackOp, packOp.getSource());
5239  return success();
5240  }
5241  /// unpack(destinationStyleOp(x)) -> unpack(x)
5242  if (auto dstStyleOp =
5243  unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
5244  auto destValue = cast<OpResult>(unPackOp.getDest());
5245  Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
5246  rewriter.modifyOpInPlace(unPackOp,
5247  [&]() { unPackOp.setDpsInitOperand(0, newDest); });
5248  return success();
5249  }
5250  /// extract_slice(unpack(x into y)) -> unpack(x into extract_slice(y))
5251  if (unPackOp->hasOneUse()) {
5252  auto extractSliceUser =
5253  dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
5254  if (extractSliceUser &&
5255  areAllConstantIntValue(extractSliceUser.getMixedOffsets(), 0) &&
5256  areAllConstantIntValue(extractSliceUser.getMixedStrides(), 1) &&
5257  extractSliceUser.getSourceType().getRank() ==
5258  extractSliceUser.getResultType().getRank()) {
5259  OpBuilder::InsertionGuard g(rewriter);
5260  rewriter.setInsertionPoint(unPackOp);
5261  auto newDest = rewriter.create<tensor::ExtractSliceOp>(
5262  unPackOp->getLoc(), unPackOp.getDest(),
5263  extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),
5264  extractSliceUser.getMixedStrides());
5265  rewriter.modifyOpInPlace(unPackOp, [&]() {
5266  unPackOp.setDpsInitOperand(0, newDest);
5267  unPackOp.getResult().setType(newDest.getType());
5268  });
5269  rewriter.replaceOp(extractSliceUser, unPackOp);
5270  return success();
5271  }
5272  }
5273 
5274  // Insert tensor.cast ops if static shape inference is available.
5275  SmallVector<int64_t> srcShape, destShape;
5276  if (inferStaticShape(unPackOp, srcShape, destShape)) {
5277  Location loc = unPackOp.getLoc();
5278  Value source = unPackOp.getSource();
5279  if (srcShape != unPackOp.getSourceType().getShape()) {
5280  auto newSrcType = unPackOp.getSourceType().clone(srcShape);
5281  source = rewriter.create<tensor::CastOp>(loc, newSrcType,
5282  unPackOp.getSource());
5283  }
5284  Value dest = unPackOp.getDest();
5285  if (destShape != unPackOp.getDestType().getShape()) {
5286  auto newDestType = unPackOp.getDestType().clone(destShape);
5287  dest =
5288  rewriter.create<tensor::CastOp>(loc, newDestType, unPackOp.getDest());
5289  }
5290  Value newOp = rewriter.create<UnPackOp>(
5291  loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
5292  unPackOp.getOuterDimsPerm());
5293  rewriter.replaceOpWithNewOp<tensor::CastOp>(
5294  unPackOp, unPackOp.getResult().getType(), newOp);
5295  return success();
5296  }
5297 
5298  return failure();
5299 }
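// Example of the extract_slice folding above (illustrative, names made up):
//
// ```mlir
// %u = linalg.unpack %x inner_dims_pos = [0] inner_tiles = [8]
//     into %y : tensor<16x8xf32> -> tensor<128xf32>
// %s = tensor.extract_slice %u[0] [125] [1]
//     : tensor<128xf32> to tensor<125xf32>
// ```
//
// becomes an unpack whose destination is the equivalent extract_slice of %y,
// producing tensor<125xf32> directly.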
5300 
5301 bool UnPackOp::isLikeUnPad() {
5302  RankedTensorType packedTensorType = getSourceType();
5303  return isLikePadUnPad(*this, packedTensorType);
5304 }
5305 
5306 OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
5307  if (OpFoldResult reshapedSource = reshapeConstantSource(
5308  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
5309  getResult().getType()))
5310  return reshapedSource;
5311  return {};
5312 }
5313 
5314 /// Folds a tensor.cast op into a consuming UnPackOp if the
5315 /// `tensor.cast` has a source that is more static than the consuming op.
5316 ///
5317 /// Example:
5318 /// ```mlir
5319 /// %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
5320 /// %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
5321 /// ```
5322 ///
5323 /// folds into:
5324 ///
5325 /// ```mlir
5326 /// %2 = tensor.unpack %0 ... tensor<1x1x8x1xi32> -> tensor<7x?xi32>
5327 /// ```
5328 struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
5329  using OpRewritePattern<UnPackOp>::OpRewritePattern;
5330 
5331  LogicalResult matchAndRewrite(UnPackOp op,
5332  PatternRewriter &rewriter) const override {
5333  if (!tensor::hasFoldableTensorCastOperand(op))
5334  return failure();
5335 
5336  SmallVector<Type> newResultTypes(op->getResultTypes());
5337  SmallVector<Value> newOperands =
5338  tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
5339  Value sourceTensor = newOperands[0];
5340 
5341  // Get the updated mixed-tile-sizes attribute.
5342  SmallVector<OpFoldResult> newMixedTileSizes = getNewMixedTileSizes(
5343  rewriter, sourceTensor.getType(), op.getMixedTiles());
5344 
5345  // Clone op.
5346  // TODO: Strictly speaking, discardable attributes should be _discarded_ at
5347  // this point. However, in practice, we use them for things that we'd like
5348  // to preserve. Implement a better abstraction.
5349  UnPackOp newOp = rewriter.create<UnPackOp>(
5350  op.getLoc(), sourceTensor, newOperands[1], op.getInnerDimsPos(),
5351  newMixedTileSizes, op.getOuterDimsPerm());
5352  newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
5353 
5354  // Replace op.
5355  Value oldResult = op.getResult();
5356  Value newResult = newOp.getResult();
5357  Value replacement = (newResult.getType() != oldResult.getType())
5358  ? rewriter.create<tensor::CastOp>(
5359  op->getLoc(), oldResult.getType(), newResult)
5360  : newResult;
5361 
5362  rewriter.replaceOp(op, {replacement});
5363 
5364  return success();
5365  }
5366 };
5367 
5368 } // namespace linalg
5369 } // namespace mlir
5370 
5371 //===----------------------------------------------------------------------===//
5372 // LinalgDialect
5373 //===----------------------------------------------------------------------===//
5374 
5375 void LinalgDialect::getCanonicalizationPatterns(
5376  RewritePatternSet &results) const {
5377  results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp, FoldTensorCastPackOp,
5378  FoldTensorCastUnPackOp, InferStaticShapeOfOperands>(getContext());
5379 }
5380 
5381 Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
5382  Attribute value, Type type,
5383  Location loc) {
5384  return arith::ConstantOp::materialize(builder, value, type, loc);
5385 }
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static MLIRContext * getContext(OpFoldResult val)
static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic sepecified by the explicit indexing map for the MatmulO...
Definition: LinalgOps.cpp:3464
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs)
Definition: LinalgOps.cpp:1853
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
Definition: LinalgOps.cpp:312
static std::tuple< SmallVector< utils::IteratorType >, SmallVector< AffineMap > > computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel=false)
Definition: LinalgOps.cpp:2805
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim)
Produce a linalg generic that computes the final step of the softmax decomposition.
Definition: LinalgOps.cpp:2876
static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap)
Definition: LinalgOps.cpp:3442
SmallVector< int64_t > outerDimsPerm
Definition: LinalgOps.cpp:4514
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
Definition: LinalgOps.cpp:128
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
Definition: LinalgOps.cpp:2343
SmallVector< OpFoldResult > innerTiles
Definition: LinalgOps.cpp:4513
static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap)
Check if the user defined map is valid broadcast map.
Definition: LinalgOps.cpp:3457
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
Definition: LinalgOps.cpp:300
static LogicalResult verifyOutputMap(BatchMatmulOp batchMatmulOp, AffineMap opIndexingMap)
This function checks if the given AffineMap for the output of a BatchMatmulOp has exactly 3 result di...
Definition: LinalgOps.cpp:3521
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
Definition: LinalgOps.cpp:1686
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
Definition: LinalgOps.cpp:1734
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim)
Produce a linalg generic that computes the second step of the softmax decomposition: res = exp(input ...
Definition: LinalgOps.cpp:2850
static void buildMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
Definition: LinalgOps.cpp:190
static LogicalResult verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic specified by the explicit indexing map for the BatchMat...
Definition: LinalgOps.cpp:3544
static Operation * findPayloadOp(Block *body, bool initFirst=false)
Definition: LinalgOps.cpp:1486
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
Definition: LinalgOps.cpp:162
TernaryFn ternaryFn
Definition: LinalgOps.cpp:4083
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
Definition: LinalgOps.cpp:1356
static void buildBatchMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
Definition: LinalgOps.cpp:206
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
Definition: LinalgOps.cpp:2827
ElementwiseArityGroup arityGroup
Definition: LinalgOps.cpp:4077
static Speculation::Speculatability getGenericSpeculatabilityImpl(LinalgOp linalgOp)
Definition: LinalgOps.cpp:1247
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect >> &effects, LinalgOp linalgOp)
Definition: LinalgOps.cpp:1214
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
Definition: LinalgOps.cpp:2236
SmallVector< int64_t > innerDimsPos
Definition: LinalgOps.cpp:4512
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
Definition: LinalgOps.cpp:338
static void buildGenericRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
Definition: LinalgOps.cpp:989
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
Definition: LinalgOps.cpp:331
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)
Return a memref.dim or tensor.dim for the shape of v at dim.
Definition: LinalgOps.cpp:61
union mlir::linalg::@1197::ArityGroupAndKind::Kind kind
UnaryFn unaryFn
Definition: LinalgOps.cpp:4081
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false)
Definition: LinalgOps.cpp:1410
void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
Definition: LinalgOps.cpp:1515
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
Definition: LinalgOps.cpp:368
static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp, AffineMap opIndexingMap, AffineMap defaultIndexingMap, bool isLHS)
Definition: LinalgOps.cpp:3488
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs, ArrayRef< StringRef > elidedAttrs={})
Definition: LinalgOps.cpp:375
BinaryFn binaryFn
Definition: LinalgOps.cpp:4082
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
Definition: LinalgOps.cpp:226
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:188
static LogicalResult getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize, const scf::SCFTilingOptions &options)
static FailureOr< TilingResult > getTiledImplementation(RewriterBase &rewriter, TilingInterface op, ValueRange regionIterArg, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, const scf::SCFTilingOptions &options)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
Base type for affine expression.
Definition: AffineExpr.h:68
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
Definition: AffineExpr.cpp:316
AffineExpr ceilDiv(uint64_t v) const
Definition: AffineExpr.cpp:968
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
AffineMap dropResults(ArrayRef< int64_t > positions) const
Definition: AffineMap.h:299
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:334
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
Definition: AffineMap.cpp:618
unsigned getNumDims() const
Definition: AffineMap.cpp:394
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:407
unsigned getNumResults() const
Definition: AffineMap.cpp:402
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:411
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:264
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:155
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:104
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:159
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:224
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:163
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:383
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:108
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:258
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:360
MLIRContext * getContext() const
Definition: Builders.h:56
Location getUnknownLoc()
Definition: Builders.cpp:27
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:262
IndexType getIndexType()
Definition: Builders.cpp:51
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:314
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:730
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void assign(const_iterator inStart, const_iterator inEnd)
Replaces the attributes with new list of attributes.
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
Attribute set(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:55
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void decreaseIndent()=0
Decrease indentation.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add a new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:426
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of a specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:518
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:410
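A minimal sketch of the insertion-point discipline these methods imply; emitIntoNewBlock is a hypothetical helper.
#include "mlir/IR/Builders.h"
#include "mlir/IR/Region.h"

// Create a block in 'region' and emit into it; the RAII guard restores the
// caller's insertion point when the function returns.
static void emitIntoNewBlock(mlir::OpBuilder &b, mlir::Region &region) {
  mlir::OpBuilder::InsertionGuard guard(b);
  mlir::Block *block = b.createBlock(&region); // insertion point: end of block
  b.setInsertionPointToStart(block);           // or explicitly at the start
  // ... create<> ops here; 'guard' rewinds the insertion point on scope exit.
}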
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
This class represents an operand of an operation.
Definition: Value.h:243
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:433
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:445
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
std::optional< RegisteredOperationName > getRegisteredInfo() const
If this operation is registered, returns the registered information, std::nullopt otherwise.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
result_iterator result_begin()
Definition: Operation.h:413
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
unsigned getNumOperands()
Definition: Operation.h:346
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_iterator result_end()
Definition: Operation.h:414
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator range over the underlying Values.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
void setDiscardableAttrs(DictionaryAttr newAttrs)
Set the discardable attribute dictionary on this operation.
Definition: Operation.h:523
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator end()
Definition: Region.h:56
Block & emplaceBlock()
Definition: Region.h:46
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:811
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:682
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to, except when the user is exceptedUser.
Definition: PatternMatch.h:666
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:594
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:578
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
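To illustrate the rewriter entry points above, here is a minimal sketch of a pattern, not one of the patterns in this file, that drops a linalg.transpose whose permutation is the identity; it assumes tensor semantics so the op has a result to replace.
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/PatternMatch.h"

struct DropIdentityTranspose
    : public mlir::OpRewritePattern<mlir::linalg::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::linalg::TransposeOp op,
                  mlir::PatternRewriter &rewriter) const override {
    if (!mlir::isIdentityPermutation(op.getPermutation()))
      return rewriter.notifyMatchFailure(op, "permutation is not the identity");
    // Forward the input; all uses of the result are rewritten.
    rewriter.replaceOp(op, op.getInput());
    return mlir::success();
  }
};
Such a pattern would be registered through RewritePatternSet::add<DropIdentityTranspose>(context), as with the add entry above.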
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
Definition: Types.cpp:104
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:191
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
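A tiny sketch combining these Value accessors; singleUseProducer is a hypothetical helper.
#include "mlir/IR/Value.h"

// Walk from a value back to its producer, but only when the value is not
// shared between users.
static mlir::Operation *singleUseProducer(mlir::Value v) {
  if (!v.hasOneUse())
    return nullptr;           // value has zero or multiple uses
  return v.getDefiningOp();   // nullptr for block arguments
}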
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static Attribute parse(AsmParser &parser, Type type)
Parse the short form [42, 100, -1] without any type prefix.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:1168
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1218
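A sketch of the composed-and-folded form, the usual way shape arithmetic is expressed; ceilDivSize is a hypothetical helper and assumes the mlir::affine namespace of current MLIR.
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"

// Compute ceil(size / tile) as an OpFoldResult; the helper folds the result
// to a constant attribute when both operands are static.
static mlir::OpFoldResult ceilDivSize(mlir::OpBuilder &b, mlir::Location loc,
                                      mlir::OpFoldResult size,
                                      mlir::OpFoldResult tile) {
  mlir::AffineExpr s0, s1;
  mlir::bindSymbols(b.getContext(), s0, s1);
  mlir::AffineMap map = mlir::AffineMap::get(0, 2, s0.ceilDiv(s1));
  return mlir::affine::makeComposedFoldedAffineApply(b, loc, map, {size, tile});
}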
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
Definition: ArithOps.cpp:2646
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
static LogicalResult reifyResultShapesImpl(OpTy op, OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Definition: LinalgOps.cpp:4336
static bool inferStaticShape(PackOp packOp, SmallVectorImpl< int64_t > &srcShape, SmallVectorImpl< int64_t > &destShape)
Returns true if srcShape or destShape differs from the corresponding shape in packOp and populates each with...
Definition: LinalgOps.cpp:4882
static bool isLikePadUnPad(PackOrUnpackOp packOp, RankedTensorType packedTensorType)
Definition: LinalgOps.cpp:4974
static SmallVector< int64_t > asShapeWithAnyValueAsDynamic(ArrayRef< OpFoldResult > ofrs)
Converts OpFoldResults to int64_t shape entries, unconditionally mapping all Values to kDynamic,...
Definition: LinalgOps.cpp:4675
static bool isInvalidPackingPosSpecification(ArrayRef< int64_t > dimsPos, size_t rank)
Returns true if dimsPos is invalid.
Definition: LinalgOps.cpp:4392
static bool areAllInBound(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > limitShape)
Returns true if each dimension of sourceShape is no larger than the corresponding dimension of limitShape.
Definition: LinalgOps.cpp:4407
static SmallVector< int64_t > getStaticTilesImpl(OpTy op)
Definition: LinalgOps.cpp:4379
static SmallVector< int64_t > getPackOpResultTypeShape(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > innerTileSizes, ArrayRef< int64_t > innerDimsPos, ArrayRef< int64_t > outerDimsPerm)
Helper for PackOp::{getResultShape,inferPackedType}.
Definition: LinalgOps.cpp:4690
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
Definition: LinalgOps.cpp:2335
static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind)
Definition: LinalgOps.cpp:4092
static PackOrUnPackTransposeResult commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp, ArrayRef< int64_t > innerPermutation, ArrayRef< int64_t > outerPermutation)
Definition: LinalgOps.cpp:4520
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:108
std::string generateLibraryCallName(Operation *op)
Returns the mangled library call name used to disambiguate between different overloads at the C level...
Definition: LinalgOps.cpp:2376
static bool paddingIsNotNeeded(PackOp op)
Returns true if the pack op does not need a padding value.
Definition: LinalgOps.cpp:4868
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
Definition: LinalgOps.cpp:2315
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx ...
Definition: LinalgOps.cpp:2326
static DenseMap< int64_t, OpFoldResult > getDimAndTileMappingImpl(OpTy op)
Definition: LinalgOps.cpp:4348
static SmallVector< OpFoldResult > getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy, SmallVector< OpFoldResult > mixedTiles)
Definition: LinalgOps.cpp:4304
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp)
Definition: LinalgOps.cpp:4855
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:99
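A sketch of how these two dim helpers are typically used together; getMixedShape is a hypothetical helper and assumes the declarations live in the linalg namespace, as the definitions in this file suggest.
#include "mlir/Dialect/Linalg/IR/Linalg.h"

// Collect per-dimension sizes of a shaped value: static dims come back as
// index attributes, dynamic dims as materialized memref.dim/tensor.dim ops.
static llvm::SmallVector<mlir::OpFoldResult>
getMixedShape(mlir::OpBuilder &b, mlir::Location loc, mlir::Value v) {
  auto type = llvm::cast<mlir::ShapedType>(v.getType());
  llvm::SmallVector<mlir::OpFoldResult> sizes;
  for (int64_t d = 0, e = type.getRank(); d < e; ++d)
    sizes.push_back(mlir::linalg::createFoldedDimOp(b, loc, v, d));
  return sizes;
}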
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp)
Definition: LinalgOps.cpp:4841
bool areTilesAndTiledDimsAllConstant(OpTy op)
Returns true if the tiles and the tiled dims are constant.
Definition: LinalgOps.cpp:4810
static SmallVector< OpFoldResult > getMixedTilesImpl(OpTy op)
Definition: LinalgOps.cpp:4363
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack)
Definition: LinalgOps.cpp:4422
FailureOr< ArrayAttr > parseIndexingMapsAttr(OpAsmParser &parser)
Definition: LinalgOps.cpp:3646
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:45
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
DynamicAPInt floor(const Fraction &f)
Definition: Fraction.h:77
DynamicAPInt ceil(const Fraction &f)
Definition: Fraction.h:79
DynamicAPInt round(const Fraction &f)
Definition: Fraction.h:136
Fraction abs(const Fraction &f)
Definition: Fraction.h:107
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
uint64_t getM(LevelType lt)
Definition: Enums.h:443
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
Definition: TensorOps.cpp:358
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
Definition: TensorOps.cpp:351
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
Definition: TensorOps.cpp:367
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:69
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition: Utils.cpp:239
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)
Return true if all of ofrs are constant integers equal to value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:311
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:791
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
Parses a single MLIR attribute in the given MLIRContext, if it is valid.
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition: AffineExpr.h:325
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
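A sketch of the static/dynamic split these OpFoldResult utilities implement; splitMixedSizes is a hypothetical helper.
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"

// Split mixed sizes into the (static int64_t, dynamic Value) operand form
// many ops store, and materialize an SSA value when one is required.
static void splitMixedSizes(mlir::OpBuilder &b, mlir::Location loc,
                            llvm::ArrayRef<mlir::OpFoldResult> mixed) {
  llvm::SmallVector<mlir::Value> dynamicVec;
  llvm::SmallVector<int64_t> staticVec;
  mlir::dispatchIndexOpFoldResults(mixed, dynamicVec, staticVec);
  if (!mixed.empty()) {
    // Forces a Value, creating an arith.constant for a static entry.
    mlir::Value v = mlir::getValueOrCreateConstantIndexOp(b, loc, mixed.front());
    (void)v;
  }
}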
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition: Utils.cpp:1297
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to avoid using the classes in the detail namespace.
Definition: AffineExpr.cpp:621
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)
Returns a permutation vector that drops the input dims in dropPositions from inputPerm.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
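A small worked sketch of the permutation helpers above; the vectors are illustrative and permutationRoundTrip is a hypothetical helper.
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "llvm/ADT/SmallVector.h"
#include <cassert>

// Apply a permutation, invert it, and apply the inverse to round-trip.
static void permutationRoundTrip() {
  llvm::SmallVector<int64_t> perm = {2, 0, 1};
  assert(mlir::isPermutationVector(perm));
  llvm::SmallVector<int64_t> shape = {4, 8, 16};
  mlir::applyPermutationToVector(shape, perm);  // shape == {16, 4, 8}
  llvm::SmallVector<int64_t> inv = mlir::invertPermutationVector(perm);
  mlir::applyPermutationToVector(shape, inv);   // shape == {4, 8, 16} again
}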
Fold transpose with transpose.
Definition: LinalgOps.cpp:1996
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
Definition: LinalgOps.cpp:1999
This pattern canonicalizes transpose by swapping the order of broadcast and transpose: transpose(broad...
Definition: LinalgOps.cpp:2021
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
Definition: LinalgOps.cpp:2024
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:330
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
Attribute propertiesAttr
This Attribute is used to opaquely construct the properties of the operation.
Region * addRegion()
Create a region that should be attached to the operation.
Container for result values of tiling.
Folds a tensor.cast op into a consuming PackOp if the tensor.cast has a source that is more static t...
Definition: LinalgOps.cpp:5037
LogicalResult matchAndRewrite(PackOp op, PatternRewriter &rewriter) const override
Definition: LinalgOps.cpp:5040
Folds a tensor.cast op into a consuming UnPackOp if the tensor.cast has a source that is more static...
Definition: LinalgOps.cpp:5328
LogicalResult matchAndRewrite(UnPackOp op, PatternRewriter &rewriter) const override
Definition: LinalgOps.cpp:5331